Breaking CAPTCHAs with Deep Learning: A fast.ai Approach
Ever wonder how those distorted text challenges called CAPTCHAs work—and how we might teach computers to solve them? Today we're diving into a practical machine learning project that combines computer vision and sequence recognition to crack this fascinating problem.
The Big Picture: Why CAPTCHAs Matter
Before we get into the code, let's understand what we're really doing here. CAPTCHA recognition isn't just about bypassing security measures (which you shouldn't do without permission!). It's about solving a fundamental AI challenge: teaching machines to perceive and interpret visual sequences much like humans do.
This project demonstrates a powerful pattern in machine learning: breaking complex problems into specialized sub-tasks. Rather than building one massive model to handle everything, we'll use specialized components that excel at different aspects of the problem.
The Architecture: CNN + LSTM + CTC
Our solution uses three key elements:
Convolutional Neural Networks (CNNs) to process the image
Long Short-Term Memory networks (LSTMs) to interpret the sequence
Connectionist Temporal Classification (CTC) to handle alignment uncertainty
Let's see how these pieces work together using fast.ai and PyTorch.
Step 1: Creating Our Dataset
First, we'll generate synthetic CAPTCHAs for training:
from fastai.vision.all import *
from fastai.text.all import *
from datasets import Dataset
from captcha.image import ImageCaptcha
import random
import string

def create_captcha_dataset(size=1000):
    generator = ImageCaptcha(width=160, height=60)
    data = []
    for _ in range(size):
        # Generate random 5-character label
        label = "".join(random.choices(string.ascii_uppercase + string.digits, k=5))
        img = generator.generate_image(label)
        data.append({'image': img, 'label': label})
    return Dataset.from_list(data)

# Generate 1000 synthetic CAPTCHAs
ds = create_captcha_dataset(1000)

# Display an example
show_image(ds[0]['image'], title=ds[0]['label'], figsize=(3, 5))
Step 2: The CNN Backbone - Our Visual Feature Extractor
The CNN acts as our "eyes," extracting meaningful patterns from the raw pixels:
class CNNBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        # First conv block (input: 3x60x160)
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)  # 60x160 -> 30x80
        )
        # Second conv block
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)  # 30x80 -> 15x40
        )
        # Third conv block
        self.conv_block3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))  # 15x40 -> 7x40
        )
        # Fourth conv block
        self.conv_block4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))  # 7x40 -> 3x40
        )
        # Force consistent output width for the sequence model
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 40))

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.adaptive_pool(x)
        return x
Notice how we're gradually reducing the height while preserving the width—this transforms the 2D image into a 1D sequence of features, perfect for our LSTM to process.
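To make that shape flow concrete, here's a quick sanity check you can run (the variable names are just for illustration):

# Illustrative shape check: height collapses to 1, width stays at 40
backbone = CNNBackbone()
dummy_batch = torch.randn(2, 3, 60, 160)   # two fake CAPTCHA-sized images
features = backbone(dummy_batch)
print(features.shape)                      # torch.Size([2, 512, 1, 40])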
Step 3: The CRNN Model - Combining Vision and Sequence Processing
Now we connect the CNN to an LSTM to handle the sequential nature of text:
class CRNN(nn.Module):
    def __init__(self, num_chars):
        super().__init__()
        self.cnn = CNNBackbone()
        # CNN output: [bs, 512, 1, 40]
        lstm_input_size = 512
        hidden_size = 256
        self.num_chars = num_chars
        # Bidirectional LSTM layers
        self.lstm = nn.LSTM(
            lstm_input_size,
            hidden_size,
            num_layers=2,
            bidirectional=True,
            batch_first=False,
            dropout=0.2
        )
        # Output layer maps to character probabilities
        self.fc = nn.Linear(hidden_size * 2, num_chars + 1)  # +1 for blank
        # Initialize with better values (crucial for performance)
        with torch.no_grad():
            self.fc.bias[num_chars] = -2.0  # Bias against blank prediction
            for i in range(num_chars):
                self.fc.bias[i] = 0.5  # Small positive bias for real characters

    def forward(self, images):
        # Extract CNN features
        features = self.cnn(images)  # [bs, 512, 1, 40]
        bs, C, H, W_seq = features.size()
        # Reshape for LSTM: [seq_len, batch_size, features]
        features = features.squeeze(2)
        features = features.permute(2, 0, 1)  # [40, bs, 512]
        # Pass through LSTM
        lstm_out, _ = self.lstm(features)
        # Get character probabilities
        logits = self.fc(lstm_out)  # [40, bs, num_chars+1]
        # Apply log softmax for CTC loss
        log_probs = F.log_softmax(logits, dim=2)
        return log_probs
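As a quick check (again, purely illustrative), a dummy forward pass confirms the model emits 40 time steps per image, many more than the 5 characters we actually need. Resolving that mismatch is exactly what CTC handles in the next step.

# Illustrative forward pass: 40 time steps, 37 classes (36 characters + 1 blank)
demo_model = CRNN(num_chars=36)
dummy_batch = torch.randn(2, 3, 60, 160)
log_probs = demo_model(dummy_batch)
print(log_probs.shape)   # torch.Size([40, 2, 37])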
Step 4: CTC Loss - The Special Sauce
The key insight: we don't know exactly which parts of the image correspond to which characters. CTC loss solves this alignment problem:
class CTCLossFlat(nn.Module):
    def __init__(self, blank_token=0, pad_token=-1):
        super().__init__()
        self.blank_token = blank_token
        self.pad_token = pad_token
        self.ctc = nn.CTCLoss(blank=blank_token, reduction='mean', zero_infinity=True)

    def forward(self, log_probs, targets):
        # log_probs: [T, B, C] - Time steps, Batch size, Classes
        # targets: [B, S] - Batch size, Sequence length (padded)
        T, B, C = log_probs.shape
        # Input lengths = full time steps for each batch item
        input_lengths = torch.full((B,), T, dtype=torch.long, device=log_probs.device)
        # Target lengths = number of non-pad tokens in each label
        target_lengths = (targets != self.pad_token).sum(dim=1)
        # Flatten targets for CTC
        targets_flat = torch.cat([t[t != self.pad_token] for t in targets])
        return self.ctc(log_probs, targets_flat, input_lengths, target_lengths)

    def decodes(self, x):
        # Convert model output to text predictions using CTC decoding rules
        if x.ndim == 3:
            x = x.permute(1, 0, 2)  # [B, T, C]
        preds = x.argmax(-1)  # [B, T]
        decoded = []
        for pred in preds:
            tokens = []
            prev = self.blank_token
            # Remove duplicates and blanks (CTC decoding rules)
            for p in pred.cpu().tolist():
                if p != prev and p != self.blank_token:
                    tokens.append(p)
                prev = p
            decoded.append(tokens)
        # Pad and wrap in TensorText
        max_len = max(len(seq) for seq in decoded)
        padded = torch.full((len(decoded), max_len), self.pad_token, device=x.device)
        for i, seq in enumerate(decoded):
            padded[i, :len(seq)] = torch.tensor(seq, device=x.device)
        return TensorText(padded)
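To build intuition for those decoding rules, here's a tiny hand-made example (assuming the blank is class index 36, the model's last output): repeated frame predictions collapse to a single character, and a blank in between is what lets a genuine double letter survive.

# Toy illustration of CTC greedy decoding (indices chosen arbitrarily)
blank = 36
frames = [7, 7, blank, 7, 12, 12, blank, blank, 30]   # per-frame argmax over 9 time steps
decoded, prev = [], blank
for f in frames:
    if f != prev and f != blank:   # drop repeats and blanks
        decoded.append(f)
    prev = f
print(decoded)   # [7, 7, 12, 30] -- the blank between the first two 7s keeps both of them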
Step 5: Training with fast.ai
Now we bring everything together using fast.ai's high-level APIs:
# Process labels for our model
class CaptchaStr(str):
    # Minimal stand-in: the original post never shows this class; a plain str subclass is enough here
    pass

class TokenizeLabel(Transform):
    vocab = list(string.ascii_uppercase + string.digits)

    def __init__(self):
        self.stoi = {v: k for k, v in enumerate(self.vocab)}

    def encodes(self, x: str):
        return TensorText(tensor([self.stoi[c] for c in x]))

    def decodes(self, x: TensorText):
        indices = x.detach().cpu().flatten().tolist()
        # Skip pad tokens (-1) that can appear in CTC-decoded predictions
        return CaptchaStr(''.join(self.vocab[int(i)] for i in indices if int(i) >= 0))

# Create DataBlock and DataLoaders
dblock = DataBlock(
    blocks=(ImageBlock, TransformBlock(type_tfms=TokenizeLabel())),
    get_x=lambda o: o['image'],
    get_y=lambda o: o['label'],
    splitter=RandomSplitter(),
    batch_tfms=[Normalize()]
)
dls = dblock.dataloaders(ds, bs=16)

# Initialize model and loss
model = CRNN(len(dls.vocab))
# The blank class is the model's last output index (num_chars), so the loss must use that same index
loss_func = CTCLossFlat(blank_token=len(dls.vocab))

# Apply different learning rates to different components
def split_params(model):
    return [
        params(model.cnn),   # CNN - lower learning rate
        params(model.lstm),  # LSTM - medium learning rate
        params(model.fc)     # Final layer - higher learning rate
    ]

# Create Learner and train
learn = Learner(
    dls,
    model,
    loss_func=loss_func,
    splitter=split_params,
    metrics=[CTCAccuracy(CTCDecoder(dls.vocab))],  # custom helpers, not fastai built-ins (sketched below)
    wd=1e-3
)

# Find good learning rate
learn.lr_find()

# Train with one-cycle policy
learn.fit_one_cycle(20, 1e-3)
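One note: CTCAccuracy and CTCDecoder used as metrics above are custom helpers rather than fastai built-ins, and this post doesn't show their definitions. A minimal sketch, assuming a sequence-level exact-match metric over greedily decoded predictions, might look like this:

# Hypothetical sketch: not the original helpers, just one plausible implementation
class CTCDecoder:
    "Greedy CTC decode: per-frame argmax, collapse repeats, drop blanks, map to characters."
    def __init__(self, vocab):
        self.vocab, self.blank = list(vocab), len(vocab)
    def __call__(self, log_probs):                    # log_probs: [T, B, C]
        preds = log_probs.argmax(-1).permute(1, 0)    # [B, T]
        out = []
        for seq in preds.cpu().tolist():
            chars, prev = [], self.blank
            for p in seq:
                if p != prev and p != self.blank:
                    chars.append(self.vocab[p])
                prev = p
            out.append(''.join(chars))
        return out

class CTCAccuracy(Metric):
    "Fraction of CAPTCHAs whose full decoded string matches the label exactly."
    def __init__(self, decoder): self.decoder = decoder
    def reset(self): self.correct, self.total = 0, 0
    def accumulate(self, learn):
        preds = self.decoder(learn.pred)
        targs = [''.join(self.decoder.vocab[int(i)] for i in t if int(i) >= 0) for t in learn.y]
        self.correct += sum(p == t for p, t in zip(preds, targs))
        self.total += len(targs)
    @property
    def value(self): return self.correct / self.total if self.total else None
    @property
    def name(self): return 'acc'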
Results: From 0% to 85% Accuracy
The results are impressive! Our model progresses from complete confusion to 85% accuracy in just 20 training epochs:
epoch train_loss valid_loss acc time
0 5.699660 3.764331 0.000000 00:02
...
19 0.098768 0.001344 0.855000 00:02
Let's look at some of the model's predictions. It correctly identifies most CAPTCHAs, even with significant distortion and overlapping characters!
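The post doesn't include the inspection code itself, but here's a hedged sketch of how you could decode one validation batch back to text using the pieces defined above (names like pred_texts are just for illustration):

# Sketch: greedy-decode one validation batch and compare against the true labels
learn.model.eval()
xb, yb = dls.valid.one_batch()
with torch.no_grad():
    log_probs = learn.model(xb)                 # [40, bs, 37]
decoded = loss_func.decodes(log_probs)          # CTC greedy decode -> padded index tensor
vocab = TokenizeLabel.vocab
pred_texts = [''.join(vocab[int(i)] for i in row if int(i) >= 0) for row in decoded]
true_texts = [''.join(vocab[int(i)] for i in row) for row in yb]
for p, t in zip(pred_texts[:5], true_texts[:5]):
    print(f'pred: {p:8s} true: {t}')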
The Machine Learning Mindset
What makes this approach successful isn't just the specific code, but the problem-solving pattern:
Divide and conquer: Separate image recognition from sequence interpretation
Choose specialized tools: CNNs for visual features, LSTMs for sequences
Handle uncertainty: Use CTC loss to manage the alignment problem
Smart initialization: Bias the model against blank predictions to avoid "blank collapse"
Fine-tune learning rates: Use different rates for different components
This same pattern can help solve many complex machine learning challenges:
Medical time series analysis
Processing medical imaging data with annotations
Want to improve this model? Consider:
Data augmentation (rotation, noise, blur) for more robust training (see the sketch after this list)
Trying transformer architectures instead of LSTM
Exploring beam search during decoding
Testing on real-world CAPTCHAs (ethically, of course!)
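For the first of those ideas, a minimal sketch using fastai's built-in augmentations might look like the following (dblock_aug and dls_aug are hypothetical names, and this covers rotation and warping only; noise and blur would need custom transforms):

# Hedged sketch: light rotation/warp augmentation; flips are disabled because they would
# destroy character identity
dblock_aug = DataBlock(
    blocks=(ImageBlock, TransformBlock(type_tfms=TokenizeLabel())),
    get_x=lambda o: o['image'],
    get_y=lambda o: o['label'],
    splitter=RandomSplitter(),
    batch_tfms=[*aug_transforms(do_flip=False, max_rotate=5.0, max_warp=0.1), Normalize()]
)
dls_aug = dblock_aug.dataloaders(ds, bs=16)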
Remember, the most valuable skill in machine learning isn't memorizing architectures—it's knowing how to decompose problems and connect specialized components in ways that leverage their strengths.
Thinking Like a Machine Learning Engineer
The next time you face a complex ML problem, remember this approach:
What are the sub-problems?
What type of neural network is best for each part?
How do these parts need to talk to each other?
What's the right way to measure success?
By breaking down big problems into manageable pieces, even the most intimidating challenges become approachable.
So next time a CAPTCHA asks you to "prove you're human," you can smile knowing that the line between human and machine intelligence continues to blur—one squiggly character at a time.
What machine learning challenge are you tackling? Share in the comments below!