Interpretable DNA Motif Detection with a Multi-Scale CNN and Attention

Overview

This tutorial walks through building a convolutional neural network tailored to DNA sequence classification problems such as promoter prediction, splice site detection, and regulatory element identification. The focus is on a reproducible, hands-on pipeline: one-hot encoding of sequences, multi-scale convolutional feature extractors, and an attention mechanism that helps expose which positions drive model decisions.

Key components and motivations

The architecture combines three ideas. One-hot encoding preserves the exact base identity at every position without imposing an artificial ordering on the four nucleotides. Parallel convolutional branches with kernel sizes 3, 7, 15, and 25 let the network match motifs of very different widths in a single forward pass. A per-branch attention mechanism then assigns a weight to every sequence position, giving a built-in view of which regions drive each prediction.
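
To make the first of these concrete, here is a minimal standalone sketch of the encoding, mirroring the one_hot_encode method defined below (channel order A, T, G, C; any other character, such as N, stays all-zero):

import numpy as np

def one_hot(seq):
    # One row per position, one column per base channel (A, T, G, C).
    mapping = {'A': 0, 'T': 1, 'G': 2, 'C': 3}
    arr = np.zeros((len(seq), 4))
    for i, base in enumerate(seq):
        if base in mapping:
            arr[i, mapping[base]] = 1
    return arr

print(one_hot("ACGT"))
# [[1. 0. 0. 0.]   A -> channel 0
#  [0. 0. 0. 1.]   C -> channel 3
#  [0. 0. 1. 0.]   G -> channel 2
#  [0. 1. 0. 0.]]  T -> channel 1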

Code and implementation

The tutorial includes full runnable code for data generation, model building, training, and visualization. Below are the original code blocks used in the guide.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import random


# Seed NumPy, TensorFlow, and Python's random module so runs are reproducible
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)

The next block defines the DNASequenceClassifier class, which handles one-hot encoding, model construction with multi-scale convolutions and attention, synthetic data generation, training with callbacks, and evaluation with visualization.

class DNASequenceClassifier:
   def __init__(self, sequence_length=200, num_classes=2):
       self.sequence_length = sequence_length
       self.num_classes = num_classes
       self.model = None
       self.history = None
      
   def one_hot_encode(self, sequences):
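       # Encode each sequence as a (length, 4) one-hot matrix; bases outside
       # A/T/G/C (e.g. N) are left all-zero.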
       mapping = {'A': 0, 'T': 1, 'G': 2, 'C': 3}
       encoded = np.zeros((len(sequences), self.sequence_length, 4))
      
       for i, seq in enumerate(sequences):
           for j, nucleotide in enumerate(seq[:self.sequence_length]):
               if nucleotide in mapping:
                   encoded[i, j, mapping[nucleotide]] = 1
       return encoded
  
   def attention_layer(self, inputs, name="attention"):
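       # Score each position with a Dense(1), softmax the scores across the
       # sequence, scale the features by those weights, then max-pool over positions.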
       attention_weights = layers.Dense(1, activation='tanh', name=f"{name}_weights")(inputs)
       attention_weights = layers.Flatten()(attention_weights)
       attention_weights = layers.Activation('softmax', name=f"{name}_softmax")(attention_weights)
       attention_weights = layers.RepeatVector(inputs.shape[-1])(attention_weights)
       attention_weights = layers.Permute([2, 1])(attention_weights)
      
       attended = layers.Multiply(name=f"{name}_multiply")([inputs, attention_weights])
       return layers.GlobalMaxPooling1D()(attended)
  
   def build_model(self):
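       # Four parallel Conv1D branches, each with its own attention pooling;
       # the pooled vectors are concatenated before the dense head.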
       inputs = layers.Input(shape=(self.sequence_length, 4), name="dna_input")
      
       conv_layers = []
       filter_sizes = [3, 7, 15, 25]
      
       for i, filter_size in enumerate(filter_sizes):
           conv = layers.Conv1D(
               filters=64,
               kernel_size=filter_size,
               activation='relu',
               padding='same',
               name=f"conv_{filter_size}"
           )(inputs)
           conv = layers.BatchNormalization(name=f"bn_conv_{filter_size}")(conv)
           conv = layers.Dropout(0.2, name=f"dropout_conv_{filter_size}")(conv)
          
           attended = self.attention_layer(conv, name=f"attention_{filter_size}")
           conv_layers.append(attended)
      
       if len(conv_layers) > 1:
           merged = layers.Concatenate(name="concat_multiscale")(conv_layers)
       else:
           merged = conv_layers[0]
      
       dense = layers.Dense(256, activation='relu', name="dense_1")(merged)
       dense = layers.BatchNormalization(name="bn_dense_1")(dense)
       dense = layers.Dropout(0.5, name="dropout_dense_1")(dense)
      
       dense = layers.Dense(128, activation='relu', name="dense_2")(dense)
       dense = layers.BatchNormalization(name="bn_dense_2")(dense)
       dense = layers.Dropout(0.3, name="dropout_dense_2")(dense)
      
       if self.num_classes == 2:
           outputs = layers.Dense(1, activation='sigmoid', name="output")(dense)
           loss = 'binary_crossentropy'
           metrics = ['accuracy', keras.metrics.Precision(name='precision'), keras.metrics.Recall(name='recall')]
       else:
           outputs = layers.Dense(self.num_classes, activation='softmax', name="output")(dense)
           loss = 'categorical_crossentropy'
           metrics = ['accuracy']
      
       self.model = keras.Model(inputs=inputs, outputs=outputs, name="DNA_CNN_Classifier")
      
       optimizer = keras.optimizers.Adam(
           learning_rate=0.001,
           beta_1=0.9,
           beta_2=0.999,
           epsilon=1e-7
       )
      
       self.model.compile(
           optimizer=optimizer,
           loss=loss,
           metrics=metrics
       )
      
       return self.model
  
   def generate_synthetic_data(self, n_samples=10000):
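       # Positives get a known promoter-style motif planted at a random position;
       # negatives are pure random background, 30% of which carry a homopolymer decoy.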
       sequences = []
       labels = []
      
       positive_motifs = ['TATAAA', 'CAAT', 'GGGCGG', 'TTGACA']
       negative_motifs = ['AAAAAAA', 'TTTTTTT', 'CCCCCCC', 'GGGGGGG']
      
       nucleotides = ['A', 'T', 'G', 'C']
      
       for i in range(n_samples):
           sequence = ''.join(random.choices(nucleotides, k=self.sequence_length))
          
           if i < n_samples // 2:
               motif = random.choice(positive_motifs)
               pos = random.randint(0, self.sequence_length - len(motif))
               sequence = sequence[:pos] + motif + sequence[pos + len(motif):]
               label = 1
           else:
               if random.random() < 0.3:
                   motif = random.choice(negative_motifs)
                   pos = random.randint(0, self.sequence_length - len(motif))
                   sequence = sequence[:pos] + motif + sequence[pos + len(motif):]
               label = 0
          
           sequences.append(sequence)
           labels.append(label)
      
       return sequences, np.array(labels)
  
   def train(self, X_train, y_train, X_val, y_val, epochs=50, batch_size=32):
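       # EarlyStopping restores the best weights seen; ReduceLROnPlateau halves
       # the learning rate when validation loss stalls for 5 epochs.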
       callbacks = [
           keras.callbacks.EarlyStopping(
               monitor='val_loss',
               patience=10,
               restore_best_weights=True
           ),
           keras.callbacks.ReduceLROnPlateau(
               monitor='val_loss',
               factor=0.5,
               patience=5,
               min_lr=1e-6
           )
       ]
      
       self.history = self.model.fit(
           X_train, y_train,
           validation_data=(X_val, y_val),
           epochs=epochs,
           batch_size=batch_size,
           callbacks=callbacks,
           verbose=1
       )
      
       return self.history
  
   def evaluate_and_visualize(self, X_test, y_test):
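       # Threshold probabilities at 0.5, print a classification report, then plot
       # loss/accuracy curves, the confusion matrix, and the score distributions.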
       y_pred_proba = self.model.predict(X_test)
       y_pred = (y_pred_proba > 0.5).astype(int).flatten()
      
       print("Classification Report:")
       print(classification_report(y_test, y_pred))
      
       fig, axes = plt.subplots(2, 2, figsize=(15, 10))
      
       axes[0,0].plot(self.history.history['loss'], label='Training Loss')
       axes[0,0].plot(self.history.history['val_loss'], label='Validation Loss')
       axes[0,0].set_title('Training History - Loss')
       axes[0,0].set_xlabel('Epoch')
       axes[0,0].set_ylabel('Loss')
       axes[0,0].legend()
      
       axes[0,1].plot(self.history.history['accuracy'], label='Training Accuracy')
       axes[0,1].plot(self.history.history['val_accuracy'], label='Validation Accuracy')
       axes[0,1].set_title('Training History - Accuracy')
       axes[0,1].set_xlabel('Epoch')
       axes[0,1].set_ylabel('Accuracy')
       axes[0,1].legend()
      
       cm = confusion_matrix(y_test, y_pred)
       sns.heatmap(cm, annot=True, fmt='d', ax=axes[1,0], cmap='Blues')
       axes[1,0].set_title('Confusion Matrix')
       axes[1,0].set_ylabel('Actual')
       axes[1,0].set_xlabel('Predicted')
      
       axes[1,1].hist(y_pred_proba[y_test==0], bins=50, alpha=0.7, label='Negative', density=True)
       axes[1,1].hist(y_pred_proba[y_test==1], bins=50, alpha=0.7, label='Positive', density=True)
       axes[1,1].set_title('Prediction Score Distribution')
       axes[1,1].set_xlabel('Prediction Score')
       axes[1,1].set_ylabel('Density')
       axes[1,1].legend()
      
       plt.tight_layout()
       plt.show()
      
       return y_pred, y_pred_proba

Finally, the example main workflow demonstrates dataset creation, encoding, splits, model building, training, and evaluation.

def main():
   print(" Advanced DNA Sequence Classification with CNN")
   print("=" * 50)
  
   classifier = DNASequenceClassifier(sequence_length=200, num_classes=2)
  
   print("Generating synthetic DNA sequences...")
   sequences, labels = classifier.generate_synthetic_data(n_samples=10000)
  
   print("Encoding DNA sequences...")
   X = classifier.one_hot_encode(sequences)
  
   # 64/16/20 train/validation/test split, stratified to preserve class balance
   X_train, X_test, y_train, y_test = train_test_split(
       X, labels, test_size=0.2, random_state=42, stratify=labels
   )
   X_train, X_val, y_train, y_val = train_test_split(
       X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
   )
  
   print(f"Training set: {X_train.shape}")
   print(f"Validation set: {X_val.shape}")
   print(f"Test set: {X_test.shape}")
  
   print("Building CNN model...")
   model = classifier.build_model()
   model.summary()
  
   print("Training model...")
   classifier.train(X_train, y_train, X_val, y_val, epochs=30, batch_size=64)
  
   print("Evaluating model...")
   y_pred, y_pred_proba = classifier.evaluate_and_visualize(X_test, y_test)
  
   print(" Training and evaluation complete!")


if __name__ == "__main__":
   main()

Training tips and interpretability notes

Two callbacks do most of the work during training: EarlyStopping (patience 10, restoring the best weights) halts runs that have begun to overfit, and ReduceLROnPlateau halves the learning rate whenever validation loss stalls for five epochs. For interpretability, note that the convolutions use 'same' padding, so each branch's softmax attention weights are position-aligned with the input sequence.
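
One way to read those weights out of a trained model is to build a sub-model that exposes one branch's softmax layer. The sketch below assumes the classifier.model, X_test, and layer names from this tutorial (e.g. "attention_7_softmax" for the kernel-size-7 branch):

import numpy as np
from tensorflow import keras

# Sub-model that outputs the attention weights of the kernel-size-7 branch;
# swap the layer name to inspect a different branch.
attention_model = keras.Model(
    inputs=classifier.model.input,
    outputs=classifier.model.get_layer("attention_7_softmax").output
)

weights = attention_model.predict(X_test[:1])  # shape: (1, sequence_length)
top_positions = np.argsort(weights[0])[::-1][:10]
print("Highest-attention positions:", top_positions)

Plotting weights[0] against position, with the planted motif location overlaid, is a quick sanity check that the attention lands where the signal actually is.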

Limitations and practical considerations

The synthetic benchmark is deliberately easy: positives contain an exact planted motif in uniform random background, so strong scores here do not predict performance on real genomic data, where motifs are degenerate, context-dependent, and classes are rarely balanced. Attention weights are a diagnostic, not a faithful explanation; treat highlighted positions as hypotheses and validate candidate motifs independently before drawing biological conclusions.

Resources

The tutorial includes the full code, notebooks, and visualization examples needed to reproduce the experiments and adapt the architecture to other sequence-based tasks.