Quantization is an effective technique for significantly reducing the memory footprint and computational cost of large models by storing parameters (and often activations) at lower numerical precision.
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import os
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
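Before reaching for any built-in tooling, it helps to see the core idea in a few lines. The sketch below (variable names are illustrative, not part of any PyTorch API) maps a float tensor onto an int8 grid using a scale and zero point, then maps it back, which is the basic affine quantization scheme most libraries implement.
w = torch.randn(4, 4)
qmin, qmax = -128, 127                        # int8 range
scale = (w.max() - w.min()) / (qmax - qmin)   # step size of the integer grid
zero_point = int(round((qmin - w.min() / scale).item()))  # integer that represents 0.0
w_q = torch.clamp(torch.round(w / scale) + zero_point, qmin, qmax).to(torch.int8)
w_dq = (w_q.float() - zero_point) * scale     # dequantize back to float
print((w - w_dq).abs().max().item())          # error is bounded by roughly scale / 2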
# Standard Full Precision Model
class LeNet(nn.Module):
    def __init__(self, in_channels=1, n_outputs=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(16*5*5, 128)  # a 28x28 input is 16x5x5 after two conv/pool stages
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, n_outputs)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # flatten for the fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = LeNet().to(device)
print(model.conv1.weight.dtype)
torch.float32
PyTorch offers an easy way to convert a model to half precision: calling .half() casts all parameters and buffers to torch.float16. We can see the effect on the saved model size.
def print_model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f MB" % (os.path.getsize("tmp.pt") / 1e6))
    os.remove("tmp.pt")
print_model_size(model)
model.half()
print_model_size(model)
0.25 MB
0.13 MB
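The halving is just arithmetic: every parameter goes from 4 bytes (float32) to 2 bytes (float16). We can check this against the parameter count (the saved file is slightly larger than the raw weights due to serialization overhead):
n_params = sum(p.numel() for p in model.parameters())
print(n_params)                                   # 61482 for this LeNet
print(f"{n_params * 4 / 1e6:.2f} MB as float32")  # ~0.25 MB
print(f"{n_params * 2 / 1e6:.2f} MB as float16")  # ~0.12 MB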
We also have to make sure the input data is half precision, so the transform pipeline converts each image to torch.float16.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.ConvertImageDtype(torch.float16),
    torchvision.transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST('../data', train=True, transform=transforms, download=False)
test_dataset = torchvision.datasets.MNIST('../data', train=False, transform=transforms, download=False)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False)
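A quick sanity check that the loader actually yields half-precision batches:
images, labels = next(iter(train_loader))
print(images.dtype)  # torch.float16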
def train(model, optimizer, criterion, train_loader, epochs=1):
    model.train()
    for i in range(epochs):
        train_loss = 0
        for inputs, target in train_loader:
            inputs = inputs.to(device)
            target = target.to(device)
            prediction = model(inputs)
            loss = criterion(prediction, target)
            train_loss += loss.item()
            model.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {i}, NLL: {train_loss / len(train_loader.dataset)}")
def test(model, optimizer, criterion, test_loader):
    model.eval()
    with torch.no_grad():
        num_correct = 0
        test_loss = 0
        for inputs, target in test_loader:
            inputs = inputs.to(device)
            target = target.to(device)
            output = model(inputs)
            loss = criterion(output, target)
            test_loss += loss.item()
            _, predictions = torch.max(output, -1)
            num_correct += (predictions == target).sum().item()
        accuracy = (num_correct / len(test_loader.dataset)) * 100
        print(f"Test Accuracy: {accuracy}, NLL: {test_loss}")
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train(model, optimizer, criterion, train_loader, epochs=1)
test(model, optimizer, criterion, test_loader)
Epoch 0, NLL: 0.041533625284830726
Test Accuracy: 90.52, NLL: 96.3634033203125
Now, most articles on this topic revolve around post-training quantization. PyTorch implements the following techniques (the first is sketched below):
- dynamic quantization (weights quantized; activations read/stored in floating point and quantized on the fly for compute)
- static quantization (weights quantized, activations quantized; calibration required after training)
- quantization aware training (weights quantized, activations quantized; quantization numerics modeled during training)
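As a minimal sketch of the first technique (assuming a recent PyTorch where the tooling lives under torch.ao.quantization; dynamic quantization targets CPU inference and only covers the layer types you list, such as nn.Linear):
import torch.ao.quantization as tq

fp32_model = LeNet()  # fresh float32 copy, since our earlier model was cast to half
int8_model = tq.quantize_dynamic(fp32_model, {nn.Linear}, dtype=torch.qint8)
print_model_size(fp32_model)  # all weights stored as float32
print_model_size(int8_model)  # Linear weights stored as int8
Since the Linear layers dominate this model's parameter count, the size reduction for those weights approaches 4x.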