Interpretable DNA Motif Detection with a Multi-Scale CNN and Attention
Overview
This tutorial walks through building a convolutional neural network tailored to DNA sequence classification problems such as promoter prediction, splice site detection, and regulatory element identification. The focus is on a reproducible, hands-on pipeline: one-hot encoding of sequences, multi-scale convolutional feature extractors, and an attention mechanism that helps expose which positions drive model decisions.
Key components and motivations
- One-hot encoding preserves discrete nucleotide identity and turns each sequence into a length × 4 matrix that feeds directly into 1D convolutions (a small sketch follows this list).
- Multi-scale convolutions (kernels of different lengths) capture motifs of varying sizes, from short transcription factor binding sites to longer regulatory patterns.
- An attention module highlights informative positions across the sequence, providing interpretability that complements performance metrics.
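To make the encoding concrete, here is a minimal standalone sketch (not part of the tutorial class; the helper name encode_sequence is ours) that maps a short sequence to a length × 4 matrix using the same A/T/G/C column order as the one_hot_encode method below.

import numpy as np

MAPPING = {'A': 0, 'T': 1, 'G': 2, 'C': 3}

def encode_sequence(seq):
    # One row per position, one column per nucleotide; unknown bases stay all-zero.
    encoded = np.zeros((len(seq), 4))
    for j, nucleotide in enumerate(seq.upper()):
        if nucleotide in MAPPING:
            encoded[j, MAPPING[nucleotide]] = 1
    return encoded

print(encode_sequence("ACGT"))
# [[1. 0. 0. 0.]    A
#  [0. 0. 0. 1.]    C
#  [0. 0. 1. 0.]    G
#  [0. 1. 0. 0.]]   T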
Code and implementation
The tutorial includes full runnable code for data generation, model building, training, and visualization. Below are the original code blocks used in the guide.
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import random
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)
The next block defines the DNASequenceClassifier that handles encoding, model construction with multi-scale convolutions and attention, synthetic data generation, training with callbacks, and visualization.
class DNASequenceClassifier:
    def __init__(self, sequence_length=200, num_classes=2):
        self.sequence_length = sequence_length
        self.num_classes = num_classes
        self.model = None
        self.history = None

    def one_hot_encode(self, sequences):
        # Map each nucleotide to one of four channels; characters outside A/T/G/C
        # (e.g., N) are left as all-zero columns.
        mapping = {'A': 0, 'T': 1, 'G': 2, 'C': 3}
        encoded = np.zeros((len(sequences), self.sequence_length, 4))
        for i, seq in enumerate(sequences):
            for j, nucleotide in enumerate(seq[:self.sequence_length]):
                if nucleotide in mapping:
                    encoded[i, j, mapping[nucleotide]] = 1
        return encoded
    def attention_layer(self, inputs, name="attention"):
        # Score each position: (batch, L, C) -> (batch, L, 1), softmax over positions,
        # then broadcast the weights back over the channel axis and reweight the features.
        attention_weights = layers.Dense(1, activation='tanh', name=f"{name}_weights")(inputs)
        attention_weights = layers.Flatten()(attention_weights)
        attention_weights = layers.Activation('softmax', name=f"{name}_softmax")(attention_weights)
        attention_weights = layers.RepeatVector(inputs.shape[-1])(attention_weights)
        attention_weights = layers.Permute([2, 1])(attention_weights)
        attended = layers.Multiply(name=f"{name}_multiply")([inputs, attention_weights])
        return layers.GlobalMaxPooling1D()(attended)
    def build_model(self):
        inputs = layers.Input(shape=(self.sequence_length, 4), name="dna_input")
        # Parallel convolutional branches with different kernel sizes capture motifs
        # of different lengths; each branch gets its own attention pooling.
        conv_layers = []
        filter_sizes = [3, 7, 15, 25]
        for filter_size in filter_sizes:
            conv = layers.Conv1D(
                filters=64,
                kernel_size=filter_size,
                activation='relu',
                padding='same',
                name=f"conv_{filter_size}"
            )(inputs)
            conv = layers.BatchNormalization(name=f"bn_conv_{filter_size}")(conv)
            conv = layers.Dropout(0.2, name=f"dropout_conv_{filter_size}")(conv)
            attended = self.attention_layer(conv, name=f"attention_{filter_size}")
            conv_layers.append(attended)
        if len(conv_layers) > 1:
            merged = layers.Concatenate(name="concat_multiscale")(conv_layers)
        else:
            merged = conv_layers[0]
        dense = layers.Dense(256, activation='relu', name="dense_1")(merged)
        dense = layers.BatchNormalization(name="bn_dense_1")(dense)
        dense = layers.Dropout(0.5, name="dropout_dense_1")(dense)
        dense = layers.Dense(128, activation='relu', name="dense_2")(dense)
        dense = layers.BatchNormalization(name="bn_dense_2")(dense)
        dense = layers.Dropout(0.3, name="dropout_dense_2")(dense)
        if self.num_classes == 2:
            outputs = layers.Dense(1, activation='sigmoid', name="output")(dense)
            loss = 'binary_crossentropy'
            # Explicit metric objects work across TF/Keras versions, unlike the
            # bare 'precision'/'recall' strings.
            metrics = ['accuracy',
                       keras.metrics.Precision(name='precision'),
                       keras.metrics.Recall(name='recall')]
        else:
            outputs = layers.Dense(self.num_classes, activation='softmax', name="output")(dense)
            loss = 'categorical_crossentropy'
            metrics = ['accuracy']
        self.model = keras.Model(inputs=inputs, outputs=outputs, name="DNA_CNN_Classifier")
        optimizer = keras.optimizers.Adam(
            learning_rate=0.001,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-7
        )
        self.model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics
        )
        return self.model
    def generate_synthetic_data(self, n_samples=10000):
        sequences = []
        labels = []
        # Positive sequences carry a promoter-like motif (e.g., TATA box, CAAT box);
        # negatives are random background, with homopolymer runs planted in ~30% of them.
        positive_motifs = ['TATAAA', 'CAAT', 'GGGCGG', 'TTGACA']
        negative_motifs = ['AAAAAAA', 'TTTTTTT', 'CCCCCCC', 'GGGGGGG']
        nucleotides = ['A', 'T', 'G', 'C']
        for i in range(n_samples):
            sequence = ''.join(random.choices(nucleotides, k=self.sequence_length))
            if i < n_samples // 2:
                motif = random.choice(positive_motifs)
                pos = random.randint(0, self.sequence_length - len(motif))
                sequence = sequence[:pos] + motif + sequence[pos + len(motif):]
                label = 1
            else:
                if random.random() < 0.3:
                    motif = random.choice(negative_motifs)
                    pos = random.randint(0, self.sequence_length - len(motif))
                    sequence = sequence[:pos] + motif + sequence[pos + len(motif):]
                label = 0
            sequences.append(sequence)
            labels.append(label)
        return sequences, np.array(labels)
    def train(self, X_train, y_train, X_val, y_val, epochs=50, batch_size=32):
        callbacks = [
            keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=10,
                restore_best_weights=True
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,
                min_lr=1e-6
            )
        ]
        self.history = self.model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            verbose=1
        )
        return self.history
    def evaluate_and_visualize(self, X_test, y_test):
        y_pred_proba = self.model.predict(X_test)
        y_pred = (y_pred_proba > 0.5).astype(int).flatten()
        print("Classification Report:")
        print(classification_report(y_test, y_pred))
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        axes[0, 0].plot(self.history.history['loss'], label='Training Loss')
        axes[0, 0].plot(self.history.history['val_loss'], label='Validation Loss')
        axes[0, 0].set_title('Training History - Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 1].plot(self.history.history['accuracy'], label='Training Accuracy')
        axes[0, 1].plot(self.history.history['val_accuracy'], label='Validation Accuracy')
        axes[0, 1].set_title('Training History - Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()
        cm = confusion_matrix(y_test, y_pred)
        sns.heatmap(cm, annot=True, fmt='d', ax=axes[1, 0], cmap='Blues')
        axes[1, 0].set_title('Confusion Matrix')
        axes[1, 0].set_ylabel('Actual')
        axes[1, 0].set_xlabel('Predicted')
        axes[1, 1].hist(y_pred_proba[y_test == 0], bins=50, alpha=0.7, label='Negative', density=True)
        axes[1, 1].hist(y_pred_proba[y_test == 1], bins=50, alpha=0.7, label='Positive', density=True)
        axes[1, 1].set_title('Prediction Score Distribution')
        axes[1, 1].set_xlabel('Prediction Score')
        axes[1, 1].set_ylabel('Density')
        axes[1, 1].legend()
        plt.tight_layout()
        plt.show()
        return y_pred, y_pred_proba
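Before running the full workflow, a quick shape check can catch wiring mistakes early. The snippet below is our own usage sketch (not from the original guide): it builds the model with a short sequence length, encodes a handful of random sequences, and confirms that one sigmoid score per sequence comes out.

# Illustrative smoke test of the architecture on a tiny configuration.
tiny = DNASequenceClassifier(sequence_length=50, num_classes=2)
tiny.build_model()
seqs = [''.join(random.choices("ATGC", k=50)) for _ in range(8)]
X_tiny = tiny.one_hot_encode(seqs)
print(tiny.model.predict(X_tiny, verbose=0).shape)  # expected: (8, 1)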
Finally, the example main workflow demonstrates dataset creation, encoding, splits, model building, training, and evaluation.
def main():
    print("Advanced DNA Sequence Classification with CNN")
    print("=" * 50)
    classifier = DNASequenceClassifier(sequence_length=200, num_classes=2)
    print("Generating synthetic DNA sequences...")
    sequences, labels = classifier.generate_synthetic_data(n_samples=10000)
    print("Encoding DNA sequences...")
    X = classifier.one_hot_encode(sequences)
    X_train, X_test, y_train, y_test = train_test_split(
        X, labels, test_size=0.2, random_state=42, stratify=labels
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
    )
    print(f"Training set: {X_train.shape}")
    print(f"Validation set: {X_val.shape}")
    print(f"Test set: {X_test.shape}")
    print("Building CNN model...")
    model = classifier.build_model()
    model.summary()
    print("Training model...")
    classifier.train(X_train, y_train, X_val, y_val, epochs=30, batch_size=64)
    print("Evaluating model...")
    y_pred, y_pred_proba = classifier.evaluate_and_visualize(X_test, y_test)
    print("Training and evaluation complete!")

if __name__ == "__main__":
    main()
Training tips and interpretability notes
- Callbacks: EarlyStopping and ReduceLROnPlateau stabilize training and limit overfitting by restoring the best weights and lowering the learning rate when the validation loss plateaus.
- Interpretability: The attention modules produce position-wise weights; you can extract and plot these to see which subsequences or motif positions the model attends to (see the extraction sketch after this list).
- Synthetic vs real data: The synthetic motifs are useful for validating the pipeline and for debugging; when moving to experimental genomic data, consider class imbalance (see the class-weight sketch after this list), sequence quality, variable-length inputs (use pooling or cropping), and transfer learning from larger datasets.
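On the interpretability point, the sketch below is our own example rather than part of the original guide; it assumes the trained classifier and the encoded X_test from main, and reads out the softmax weights of one attention branch through an intermediate Keras model. The layer name follows the naming scheme used in build_model ("attention_7_softmax" for the kernel-size-7 branch).

# Hedged sketch: plot the position-wise attention weights for one test sequence.
att_model = keras.Model(
    inputs=classifier.model.input,
    outputs=classifier.model.get_layer("attention_7_softmax").output,
)
weights = att_model.predict(X_test[:1], verbose=0)[0]  # shape: (sequence_length,)

plt.figure(figsize=(10, 2))
plt.plot(weights)
plt.xlabel("Position")
plt.ylabel("Attention weight")
plt.title("Attention profile (kernel size 7 branch)")
plt.show()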
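On the class-imbalance point, one common option is to weight the loss per class. The short sketch below is our addition (assuming the X_train/y_train/X_val/y_val arrays from main) and combines scikit-learn's class_weight utility with Keras's class_weight argument to model.fit.

from sklearn.utils.class_weight import compute_class_weight

# Weight classes inversely to their frequency so a minority class is not ignored.
classes = np.unique(y_train)
weights = compute_class_weight(class_weight="balanced", classes=classes, y=y_train)
class_weight = {int(c): w for c, w in zip(classes, weights)}

classifier.model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=30,
    batch_size=64,
    class_weight=class_weight,
)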
Limitations and practical considerations
- The example uses simple motif insertion to create signal; biological signals can be noisier and overlapping.
- Attention is helpful but not a guaranteed explanation: validate attention patterns against known motifs or with additional attribution methods such as saliency or integrated gradients (a minimal gradient-saliency sketch follows this list).
- Hyperparameters (filter sizes, number of filters, dropout rates) should be tuned per task.
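For the attribution point above, here is a minimal gradient × input saliency sketch; it is our own example (assuming the binary model and one-hot test inputs from this tutorial), and integrated gradients can be built on the same GradientTape machinery.

def saliency(model, x):
    # x: one-hot batch of shape (n, sequence_length, 4).
    x = tf.convert_to_tensor(x, dtype=tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(x)
        scores = model(x, training=False)   # sigmoid outputs, shape (n, 1)
    grads = tape.gradient(scores, x)        # d(score)/d(input), shape (n, L, 4)
    # Per-position importance: gradient at the observed nucleotide, summed over channels.
    return tf.reduce_sum(grads * x, axis=-1).numpy()

imp = saliency(classifier.model, X_test[:1])[0]
print("Most influential positions:", np.argsort(-np.abs(imp))[:10])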
Resources
The tutorial includes the full code, notebooks, and visualization examples needed to reproduce the experiments and adapt the architecture to other sequence-based tasks.