Programming for AI (Lab) - Week 6
Instructor: Dr. Natasha Nigar
Topic: Matplotlib
1. Introduction to Matplotlib
Matplotlib is a powerful Python library for creating static, animated, and interactive
visualizations. It is widely used in data analysis, machine learning, and scientific research.
Key Features:
Simple syntax similar to MATLAB
Customizable plots
Support for multiple backends
Integration with NumPy and Pandas
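As a minimal sketch of the backend support and NumPy integration listed above, the snippet below selects the non-interactive Agg backend, which renders off-screen, and writes a figure to an in-memory PNG. The choice of Agg and the names `buffer` and `png_bytes` are assumptions made for illustration; on a desktop the default interactive backend usually works without any explicit selection.

```python
import io

import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend so no display is needed
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig, ax = plt.subplots()
ax.plot(x, np.sin(x))  # NumPy arrays are passed to plotting calls directly

# Render the figure to an in-memory PNG instead of opening a window
buffer = io.BytesIO()
fig.savefig(buffer, format="png")
png_bytes = buffer.getvalue()
plt.close(fig)

print("Active backend:", matplotlib.get_backend())
print("PNG size in bytes:", len(png_bytes))
```

Replacing `buffer` with a file name such as 'plot.png' would save the image to disk instead.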
2. Installation and Setup
To install Matplotlib, use the following command:
pip install matplotlib
To import the library:
import matplotlib.pyplot as plt
import numpy as np
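A quick way to confirm that the installation worked is to print the installed versions. This is only a sanity check, not part of the original lab steps.

```python
import matplotlib
import numpy as np

# Both packages expose their version as a string
print("Matplotlib:", matplotlib.__version__)
print("NumPy:", np.__version__)
```

If either import fails with ModuleNotFoundError, the package is most likely not installed in the Python environment that is running the script.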
3. Basic Plotting Functions
3.1 Line Plot
x = np.linspace(0, 10, 100)
y = np.sin(x)
plt.plot(x, y, label='Sine Wave')
plt.xlabel('X Axis')
plt.ylabel('Y Axis')
plt.title('Line Plot Example')
plt.legend()
plt.show()
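The same line plot can also be written with Matplotlib's object-oriented interface, where each call is made on an explicit Axes object. The sketch below is an equivalent of the example above under that style; the Agg backend line is an assumption so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

# plt.subplots() returns the figure and a single Axes to draw on
fig, ax = plt.subplots()
ax.plot(x, y, label='Sine Wave')
ax.set_xlabel('X Axis')
ax.set_ylabel('Y Axis')
ax.set_title('Line Plot Example')
ax.legend()
```

Working on `ax` directly becomes useful once a figure holds more than one plot, because each Axes keeps its own labels and title.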
3.2 Scatter Plot
x = np.random.rand(50)
y = np.random.rand(50)
plt.scatter(x, y, color='r', marker='o')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Scatter Plot')
plt.show()
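A scatter plot can encode a third quantity through color. The following sketch colors each point by its y value and adds a colorbar; the seed value and the 'viridis' colormap are arbitrary choices for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the plot is reproducible
x = rng.random(50)
y = rng.random(50)

fig, ax = plt.subplots()
# c= supplies one value per point; cmap= maps those values to colors
points = ax.scatter(x, y, c=y, cmap='viridis', marker='o')
fig.colorbar(points, ax=ax, label='Y value')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Scatter Plot Colored by Y')
```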
3.3 Bar Chart
categories = ['A', 'B', 'C', 'D']
values = [3, 7, 1, 8]
plt.bar(categories, values, color='blue')
plt.title('Bar Chart')
plt.show()
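Bar charts are often easier to read with the value printed above each bar. One possible way, sketched here with the same data, is to loop over the bars and place a text label at the top of each.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend; remove to use plt.show()
import matplotlib.pyplot as plt

categories = ['A', 'B', 'C', 'D']
values = [3, 7, 1, 8]

fig, ax = plt.subplots()
bars = ax.bar(categories, values, color='blue')

# Center a label horizontally on each bar, just above its top edge
for bar, value in zip(bars, values):
    x_center = bar.get_x() + bar.get_width() / 2
    ax.text(x_center, bar.get_height(), str(value), ha='center', va='bottom')

ax.set_title('Bar Chart with Value Labels')
```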
4. Customization and Styling
Changing Line Styles and Colors
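# Recreate the sine data, because x and y were overwritten in the scatter plot example
x = np.linspace(0, 10, 100)
y = np.sin(x)
plt.plot(x, y, linestyle='--', color='red', linewidth=2)
plt.show()
Adding Grid and Annotations
plt.plot(x, y)
plt.grid(True)
# xy is the point being annotated; xytext is where the label text is placed
plt.annotate('Peak', xy=(np.pi/2, 1), xytext=(2, 0.8), arrowprops=dict(facecolor='black'))
plt.show()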
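Instead of hard-coding the peak position, the annotated point can be located from the data. The sketch below assumes a single period of the sine wave, sampled so that the maximum falls on a grid point, and uses np.argmax to find it.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

# One full period, sampled so that x = pi/2 lies on the grid
x = np.linspace(0, 2 * np.pi, 201)
y = np.sin(x)

peak_index = int(np.argmax(y))  # index of the largest y value
x_peak, y_peak = x[peak_index], y[peak_index]

fig, ax = plt.subplots()
ax.plot(x, y)
ax.grid(True)
ax.annotate('Peak', xy=(x_peak, y_peak),
            xytext=(x_peak + 1, y_peak - 0.2),
            arrowprops=dict(facecolor='black'))
```

Over a longer range that contains several peaks, np.argmax returns only the largest sampled value, so the range should be restricted first if a specific peak is wanted.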
5. Advanced Plotting Techniques
Subplots
fig, axs = plt.subplots(2, 2)
axs[0, 0].plot(x, y)
axs[0, 1].scatter(x, y)
axs[1, 0].bar(categories, values)
axs[1, 1].hist(np.random.randn(100), bins=10)
plt.tight_layout()  # adjust spacing so tick labels of neighboring panels do not overlap
plt.show()
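When every panel needs the same treatment, such as a title, the array of Axes can be traversed with `axs.flat`. The sketch below rebuilds the 2x2 grid with its own data; the titles and figure size are illustrative choices.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 10, 100)
y = np.sin(x)
categories = ['A', 'B', 'C', 'D']
values = [3, 7, 1, 8]

fig, axs = plt.subplots(2, 2, figsize=(8, 6))
axs[0, 0].plot(x, y)
axs[0, 1].scatter(x, y)
axs[1, 0].bar(categories, values)
axs[1, 1].hist(rng.standard_normal(100), bins=10)

# axs.flat visits the panels row by row
titles = ['Line', 'Scatter', 'Bar', 'Histogram']
for ax, title in zip(axs.flat, titles):
    ax.set_title(title)

fig.tight_layout()
```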
Heatmaps
# Seaborn is a separate library built on Matplotlib; install it with: pip install seaborn
import seaborn as sns
data = np.random.rand(10, 10)
sns.heatmap(data, cmap='coolwarm')
plt.show()
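A comparable heatmap can be drawn with Matplotlib alone, which avoids the extra dependency. This sketch swaps in `imshow` for the Seaborn call and adds the colorbar manually; it is an alternative, not the method used above.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
data = rng.random((10, 10))

fig, ax = plt.subplots()
# imshow draws one colored cell per array element
image = ax.imshow(data, cmap='coolwarm')
fig.colorbar(image, ax=ax)
ax.set_title('Heatmap with imshow')
```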
6. Working with Multiple Plots
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(x, y, 'g')
plt.subplot(1, 2, 2)
plt.scatter(x, y, color='b')
plt.show()
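For layouts where the panels are not all the same size, a grid specification can be used. The sketch below is one possible arrangement, not part of the original example: a wide plot across the top row and two smaller plots underneath. The variable names are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig = plt.figure(figsize=(10, 5))
grid = fig.add_gridspec(2, 2)  # 2 rows, 2 columns

ax_top = fig.add_subplot(grid[0, :])   # spans both columns of the first row
ax_left = fig.add_subplot(grid[1, 0])
ax_right = fig.add_subplot(grid[1, 1])

ax_top.plot(x, y, 'g')
ax_left.scatter(x, y, color='b')
ax_right.hist(y, bins=10)

fig.tight_layout()
```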
7. 3D Plotting with Matplotlib
# Only needed on Matplotlib versions before 3.2, where this import registers the '3d' projection
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
Z = np.sin(np.sqrt(X**2 + Y**2))
ax.plot_surface(X, Y, Z, cmap='viridis')
plt.show()
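The same surface can be viewed in two dimensions as a filled contour plot, which is often easier to read on paper. This sketch reuses the grid and function from the example above; the number of levels is an arbitrary choice.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)  # 2D coordinate grids, each of shape (100, 100)
Z = np.sin(np.sqrt(X**2 + Y**2))

fig, ax = plt.subplots()
contours = ax.contourf(X, Y, Z, levels=20, cmap='viridis')
fig.colorbar(contours, ax=ax, label='Z')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Filled Contour of the Surface')
```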
8. Case Studies
Case Study 1: Sales Data Visualization
Problem: Analyzing monthly sales performance.
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May']
sales = [300, 450, 150, 400, 550]
plt.plot(months, sales, marker='o', linestyle='-', color='blue')
plt.title('Monthly Sales')
plt.xlabel('Month')
plt.ylabel('Sales ($)')
plt.show()
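To analyze performance rather than only display it, the month-over-month change can be computed and plotted. The sketch below uses the same sales figures; the green and red color scheme is an illustrative choice.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

months = ['Jan', 'Feb', 'Mar', 'Apr', 'May']
sales = np.array([300, 450, 150, 400, 550])

changes = np.diff(sales)  # difference between each month and the previous one
best_month = months[int(np.argmax(sales))]

fig, ax = plt.subplots()
colors = ['green' if change >= 0 else 'red' for change in changes]
ax.bar(months[1:], changes, color=colors)  # January has no previous month
ax.axhline(0, color='black', linewidth=1)
ax.set_title('Month-over-Month Change in Sales')
ax.set_xlabel('Month')
ax.set_ylabel('Change ($)')

print("Best month:", best_month)  # prints "Best month: May"
```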
Case Study 2: Weather Data Analysis
Problem: Visualizing temperature variations.
days = np.arange(1, 31)
temps = np.random.normal(25, 5, 30)
plt.plot(days, temps, marker='s', linestyle='-', color='red')
# Shade a band of ±2 °C around the readings
plt.fill_between(days, temps - 2, temps + 2, alpha=0.2)
plt.title('Temperature Variation')
plt.xlabel('Days')
plt.ylabel('Temperature (°C)')
plt.show()
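Daily readings are noisy, so a moving average can make the trend easier to see. The following sketch smooths simulated temperatures with a 5-day window; the window length and the seed are assumptions chosen for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the data is reproducible
days = np.arange(1, 31)
temps = rng.normal(25, 5, 30)

window = 5
kernel = np.ones(window) / window
# 'valid' keeps only positions where the window fits fully: 30 - 5 + 1 = 26 values
smoothed = np.convolve(temps, kernel, mode='valid')
# Each average is centered on the middle day of its window (days 3 to 28)
centered_days = days[window // 2 : -(window // 2)]

fig, ax = plt.subplots()
ax.plot(days, temps, marker='s', color='red', alpha=0.4, label='Daily')
ax.plot(centered_days, smoothed, color='black', label='5-day average')
ax.set_xlabel('Days')
ax.set_ylabel('Temperature (°C)')
ax.legend()
```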
Case Study 3: Stock Market Trends
Problem: Comparing stock price trends.
days = np.arange(1, 31)
stock1 = np.cumsum(np.random.randn(30)) + 100
stock2 = np.cumsum(np.random.randn(30)) + 120
plt.plot(days, stock1, label='Stock A', color='green')
plt.plot(days, stock2, label='Stock B', color='blue')
plt.legend()
plt.title('Stock Market Trends')
plt.xlabel('Days')
plt.ylabel('Stock Price ($)')
plt.show()
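Because the two stocks start at different price levels, comparing their trends is easier after converting each series to a percentage change from its first day. This is a sketch of that normalization on simulated data; the seed and variable names are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the data is reproducible
days = np.arange(1, 31)
stock1 = np.cumsum(rng.standard_normal(30)) + 100
stock2 = np.cumsum(rng.standard_normal(30)) + 120

# Percentage change relative to the first day of each series
pct1 = (stock1 / stock1[0] - 1) * 100
pct2 = (stock2 / stock2[0] - 1) * 100

fig, ax = plt.subplots()
ax.plot(days, pct1, label='Stock A', color='green')
ax.plot(days, pct2, label='Stock B', color='blue')
ax.legend()
ax.set_title('Stock Performance Since Day 1')
ax.set_xlabel('Days')
ax.set_ylabel('Change (%)')
```

Both lines start at 0%, so the gap between them at any later day shows the difference in relative performance.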