Neural networks are composed of many connected neurons organized into hierarchical layers. Each connection between neurons has an associated weight $w_i$ that changes through learning. The output of a neuron is computed by taking the weighted sum of its inputs $x_i$, adding a bias term $b$, and passing the result through a non-linear activation function $\sigma$.
$$ y = \sigma \left( \sum_{i} w_i x_i + b \right) $$

The activation function plays a very important role in neural networks: the non-linearity it injects is what allows a network to model complex relationships.
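As a minimal sketch, this computation for a single neuron can be written directly in PyTorch (the inputs, weights, and bias below are arbitrary example values):
import torch
# inputs x_i, weights w_i, and bias b -- arbitrary example values
x = torch.tensor([0.5, -1.0, 2.0])
w = torch.tensor([0.1, 0.4, -0.3])
b = torch.tensor(0.2)
# weighted sum plus bias, passed through a sigmoid activation
y = torch.sigmoid(torch.dot(w, x) + b)
print(y)  # tensor(0.3208)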
Many activation functions have been used over the years; in early neural network research the most popular were sigmoidal functions such as the logistic function and the hyperbolic tangent.
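For reference, these two functions are defined as
$$ \sigma(x) = \frac{1}{1 + e^{-x}}, \qquad \tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}, $$
and both squash their input into a bounded range: $(0, 1)$ for the logistic sigmoid and $(-1, 1)$ for $\tanh$.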
%matplotlib inline
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
# the seaborn styles shipped with matplotlib were renamed in 3.6
matplotlib.style.use('seaborn-v0_8-whitegrid')
x = torch.linspace(-10, 10, 100)
plt.plot(x, torch.sigmoid(x), label="sigmoid")
plt.plot(x, torch.tanh(x), label="tanh")
plt.legend()
We can also plot their derivatives using PyTorch's autograd.
def plot_derivative(x, func):
    # evaluate the activation and sum it so we can backpropagate a scalar
    f = func(x)
    out = f.sum()
    out.backward()
    # x.grad now holds d func(x) / dx at every point
    plt.plot(x.detach(), x.grad.detach(), label=f"{func.__name__}")
    plt.legend()
    # clear the gradient so successive calls don't accumulate
    x.grad.zero_()
x = torch.linspace(-10, 10, 100, requires_grad=True)
plot_derivative(x, torch.sigmoid)
plot_derivative(x, torch.tanh)
Sigmoidal functions work well for shallow networks; however, they tend to suffer from the vanishing gradient problem in deeper networks.
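As a minimal sketch of this effect, we can chain a single sigmoid many times and backpropagate through the chain; since the sigmoid's derivative never exceeds 0.25, the gradient at the input shrinks roughly geometrically with depth:
z = torch.ones(1, requires_grad=True)
out = z
# apply a sigmoid 20 times, mimicking a deep stack of sigmoid layers
for _ in range(20):
    out = torch.sigmoid(out)
out.backward()
print(z.grad)  # a vanishingly small gradient reaches the input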
Modern deep learning models mostly use rectifiers, whose output is unbounded for positive inputs and whose gradient there is (mostly) constant.
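For reference, the three rectifiers plotted below are defined (using PyTorch's default constants) as
$$ \mathrm{ReLU}(x) = \max(0, x), \qquad \mathrm{ELU}(x) = \begin{cases} x & x > 0 \\ \alpha\,(e^{x} - 1) & x \le 0, \end{cases} $$
with $\alpha = 1$ for ELU, while SELU is a scaled ELU, $\mathrm{SELU}(x) = \lambda\,\mathrm{ELU}_{\alpha}(x)$, with $\lambda \approx 1.0507$ and $\alpha \approx 1.6733$.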
with torch.no_grad():
    # x requires grad, so detach it before handing it to matplotlib
    plt.plot(x.detach(), torch.relu(x), label="relu")
    plt.plot(x.detach(), F.selu(x), label="selu")
    plt.plot(x.detach(), F.elu(x), label="elu")
    plt.legend()
plot_derivative(x, torch.relu)
plot_derivative(x, F.selu)
plot_derivative(x, F.elu)
plt.legend()