1 Introduction

Drug discovery and development is an extremely long and expensive process that aims to find innovative medical compounds ready to be formulated, synthesized and administered to a patient [1]. Despite the most recent scientific advances and ever-increasing understanding of biological systems, most of these compounds fail to be selected due to a lack of desirable molecular properties [2].

In lead optimization, only a small fraction of the molecules can pass virtual screening and enter clinical development. The higher the quality of these preclinical candidates, the higher the probability of successful drug development [3]. However, the major cost of this operation stems from exploring the entire chemical space to synthesize only a few drug candidates. Thus, the search for new classes of compounds with a suitable pharmacological profile from a small amount of labeled data is paramount [4].

Currently, artificial intelligence assists almost every step of drug discovery, including target identification, lead discovery and optimization, and preclinical data generation. These methods reduce the number of iterations required to discover novel and active compounds while eliminating those that are inactive, reactive or toxic [5].

With the evolution of artificial intelligence, developments in deep learning (DL) have played a crucial role in optimizing drug discovery. These algorithms motivated the application of new graph representation learning techniques to model systems of drug interaction and prediction. However, with only a few labeled molecules available, deep networks struggle to generalize well and achieve acceptable performance [6].

Well-validated biological datasets (e.g., Tox21, SIDER) [7] are limited in size and very expensive to obtain. These scarce drug repositories include only a few compounds that share the same set of molecular properties. The resulting lack of biological information, including molecules sharing similar properties, limits the performance of conventional approaches. This precondition sets the challenge of developing models that effectively predict the properties of small molecules in few-shot learning scenarios [8, 9].

Recent research has demonstrated that simple machine learning algorithms and random forest predictors are effective in learning meaningful structural information from just a few labelled compounds [10, 11]. On the other hand, transfer learning and data augmentation techniques also provide the domain knowledge required in cases where examples with supervised information are hard or impossible to obtain [12, 13].

Nonetheless, these techniques are often too expensive and resource-intensive to perform in drug discovery campaigns. More recently, non-trivial few-shot learning predictors have been proposed to discover the properties of new molecules and recognize potential drug candidates for further development [14, 15]. These methods attempt to learn from a set of molecular property prediction tasks and generalize to new chemical properties given just a few available molecules.

Small molecules can be viewed as comprehensive graph structures, where atoms are represented as nodes and chemical bonds as edges shared by neighbors in a graph [16,17,18]. These graph-level representations account for the spatial arrangement of atoms and bonds as well as interactions between neighboring nodes and edges. This approach is more suitable for representation learning than sequence-based methods that describe molecules as sequential features such as SMILES (Simplified Molecular Input Line Entry System) strings [19].

These unique graph features can be exploited by deep learning pipelines, which nevertheless fail to predict molecular properties when only limited data are available. This limitation prompts the need to explore models that quickly adapt across tasks to predict new properties on few-shot data [20, 21].

2 Related work

Few-shot learning methods have emerged as critical tools to accelerate and optimize drug discovery. These algorithms aim to generalize from small data collections and predict new systems from a limited amount of labeled information. Recently, few-shot models have proven effective in modeling molecules as comprehensive graph structures used for graph-based representation learning. Graph neural networks leverage this information to build molecular embeddings by treating atoms as nodes and chemical bonds as edges. Node and edge embeddings can later be used to support the prediction of molecular properties. Deep networks such as convolutional neural networks also manipulate these continuous vectorial representations to encode molecular graphs in a form suitable for few-shot molecule prediction.

2.1 Few-shot learning

Humans have an innate ability to recognize new objects and representations quickly from just a few examples. Prior knowledge helps to distinguish new concepts based on a generalized perception of an extensive and diverse set of representations. Thus, the ability to few-shot learn different representations by observing a concept and generating meaningful and diverse variations is very important when classifying new instances of unknown concepts [22].

This strategy attempts to adapt from previously seen classes to predict unseen representations from just a few labeled examples. This idea of few-shot generalization gave rise to few-shot learning methods [23].

Few-shot learning (FSL) was introduced by Fei-Fei et al. [24] in the field of computer vision and image processing. This approach presents a fundamental feature, which is the ability to predict based on prior experience by transferring knowledge across tasks.

In drug discovery, molecular property prediction is a few-shot learning problem since only a few molecules can pass virtual screening to be further evaluated in lead optimization. At this stage, few-shot models attempt to learn a predictor from a set of molecular property tasks and generalize to new chemical properties from a small number of labeled molecules.

Altae-Tran et al. [25] introduced an iteratively refined LSTM (IterRefLSTM) model and adapted a classic FSL algorithm to handle these molecular property prediction tasks by iteratively coevolving undirected graph features. In this pioneering study, a graph neural network acts as a molecular encoder and learns few-shot representations across tasks to provide an inductive bias from prior experience that guides the search for the optimal model parameters.

2.2 Graph neural networks

More recently, the fast adaptability of few-shot learning methods foregrounded graph neural networks (GNNs) as promising strategies to model nonlinear systems of molecule prediction.

GNNs have shown promising results by representing molecules as topological graphs of node and edge features. Graphs are characterized by specific neighborhood aggregation functions to update node representations and iteratively build graph-level embeddings.

Graph embedding methodologies, such as graph convolutional networks (GCN) or node2vec, are able to perform feature representation learning to obtain comprehensive graph embeddings. While GCNs use node and edge aggregation to compute graph embeddings, node2vec takes inspiration from natural language processing algorithms (word2vec) by treating nodes in a graph like words in a sentence to learn their relationships [26].

In drug discovery, the generated embeddings can be used to learn and predict properties of molecular graph features representing molecules. Hu et al. [27] propose novel strategies to pretrain graph-based networks and assist the learning of local and global information for molecular property prediction using graph embeddings.

GNNs such as GCN, graph isomorphism networks (GIN), GraphSAGE and graph attention networks (GAT) have significantly improved the discovery of new chemical entities with enhanced therapeutic properties [28, 29]. Graph convolution architectures include a convolutional component to aggregate nodes and edges from close neighbors in the same receptive field. This type of graph network aggregates node features by applying convolutional filters to the aggregated nodes. GraphSAGE uses a different sampling strategy, transforming the original graph convolution training procedure into a training method divided into small batches centered on the nodes. An inductive graph convolution operation also extends the aggregation to generate relevant graph features on unseen nodes and edges [30, 31]. GAT models are an extension of the conventional GCN that perform node aggregation by adding specific attention weights to certain nodes. Thus, GAT computes node aggregates based on attention scores that capture the most meaningful parts of a node’s neighborhood [32]. Among graph neural network architectures, GIN presents the maximum discriminative power by generalizing the Weisfeiler–Lehman (WL) isomorphism test to capture different graph structures [33].
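To make the difference between these aggregation schemes concrete, the following minimal PyTorch sketch contrasts a GraphSAGE-style mean aggregation with a GAT-style attention-weighted aggregation over a single node's neighborhood; the tensors and the attention layer are illustrative stand-ins rather than the implementations used in the cited works.

```python
import torch
import torch.nn.functional as F

# Illustrative tensors: a centre node with 5 neighbours, 64-dimensional features.
h_v = torch.randn(64)          # centre node features
h_neigh = torch.randn(5, 64)   # features of the neighbouring nodes

# GraphSAGE-style aggregation: mean of neighbour features, concatenated with the node itself.
h_sage = torch.cat([h_v, h_neigh.mean(dim=0)])        # shape (128,)

# GAT-style aggregation: learned attention scores weight each neighbour's contribution.
attn = torch.nn.Linear(2 * 64, 1)                     # illustrative attention layer
scores = torch.stack([attn(torch.cat([h_v, h_u])) for h_u in h_neigh]).squeeze(-1)
alpha = F.softmax(F.leaky_relu(scores), dim=0)        # normalised attention weights
h_gat = (alpha.unsqueeze(-1) * h_neigh).sum(dim=0)    # attention-weighted sum, shape (64,)
```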

Guo et al. [34] proposed a meta-learning framework (Meta-GNN) across GNNs to predict molecular properties using a novel pre-training technique. Self-supervised learning objectives such as bond reconstruction and atom type prediction were included to improve generalization on few-shot data.

Structural aspects of these graph features model complex patterns and local dependencies between neighboring nodes and edges. Node and edge features can be later used to predict the behavior of active compounds or specific molecular properties of multiple drug candidates.

2.3 Convolutional neural networks

Convolutional neural network (CNN) architectures can process molecules to build up features using learnable convolutional layers [35]. These architectures consist of a set of layers with multiple neurons and a set of parameters representing the strength of the connections between neurons [36].

CNNs learn by applying a backpropagation algorithm that tells the network how to change its internal parameters. These updates allow CNNs to compute the representation in each layer from the representation of the previous layer. As we progress to deeper layers, the ability to distinguish global patterns and local dependencies increases [37].

Sequential layers treat the input as an image viewed as a grid where each element corresponds to a pixel. Convolutional weight matrices act as filters that operate in local receptive fields of neighboring pixels to develop structure in high-dimensional feature spaces.

In drug discovery, molecular features can be handled by convolutional architectures to model nonlinear functions of molecular structure and transform small molecules into deep representations. To explore the potential of such representations, Shi et al. [38] proposed a CNN method to establish prediction models of molecular properties for automated virtual screening and ADMET prediction.

Moreover, Ståhl et al. [35] introduced a flexible CNN architecture to incorporate global and local information of deep representations to effectively predict undirected graph features representing molecules.

By converting these molecular features into continuous embedded descriptors, CNNs can use deep representations to infer the complex concepts of molecular structures and boost the prediction of molecular properties [39, 40].

3 Proposed approach

Few-shot molecular property prediction depicts a practical learning scenario where a model \(f\) is provided with a small number of training examples. This model intends to predict new molecular properties from just a few examples by generalizing from prior learning experience.

In this direction, the problem of low data can be stated as follows: given the molecular property labels \(y_i \in Y\) of a small set of molecules \(D\), the goal is to learn a function \(f\) that maps a molecule \(d_i \in D\) to its molecular property label \(y_i \in Y\) in the test data. This can be formalized by

$$\begin{aligned} f: d_i \mapsto y_{i}. \end{aligned}$$
(1)

GNNs have proven critical when optimizing the prediction of drug-like properties by exploring unique graph features [41]. Generally, the objective is to obtain an embedding for a set of nodes and edges in a graph to generalize from node-edge representation learning. Subsequently, these continuous vectorial representations are fed into a simple classifier to predict a given property or molecular label.

Conventionally, a GNN \(f\) converts a graph representation of a node \(v \in V\) to an embedding \(h\) based on the contextual information of a subgraph of its neighboring nodes \(u \in N(v)\) and edges \(e = (v,u)\). These embedded descriptors preserve the original graph structure and properties while mapping them to a reduced space of comprehensive representations (see Fig. 1).

Fig. 1

Graphical depiction of a molecule representation as a molecular graph. Node embeddings \(h_v\) are mapped from molecular graphs to a low-dimensional feature space of vectorial embeddings

However, most of the aforementioned graph-based networks neglect the practical use of such representations and fail to leverage node and edge embeddings to enhance the prediction of molecular properties. Hence, the potential of these graph embeddings has yet to be fully explored successfully.

To address the limitations of existing studies and exploit the singular properties of graph-level representations, we propose an innovative approach, FS-GNNConv, for molecular property prediction. The proposed GNN-CNN architecture explores the local dependencies of molecular graph embeddings to learn complex patterns and produce stronger features for representation learning. The challenge of low data is systematically addressed by proposing a two-module meta-learning framework to quickly adapt to new molecular properties across few-shot tasks. This strategy saves resources and promotes fast adaptation to new experimental tasks that are similar, or only marginally different, from those found in training. The major contributions of this work can be summarized as follows:

  (i) a two-module GNN-CNN architecture that accepts the compound chemical structure to exploit the rich information of graph embeddings;

  (ii) a few-shot learning (FSL) strategy to learn from task-transferable knowledge and predict the behavior of active compounds in new experimental systems;

  (iii) a meta-learning framework to iteratively optimize model parameters and successively gather generic knowledge across tasks to predict task-specific molecular properties;

  (iv) experiments on real multiproperty prediction data to demonstrate the predictive power of the proposed model when inferring specific target properties adaptively.

3.1 Embedding module: graph isomorphism network

GNNs are extremely effective in modeling molecular properties by describing the molecular structure as a graph.

In a molecular graph, each node represents an atom, and each edge represents a chemical bond between atoms. Both can be described by multiple features encoding structural and stereochemical attributes.

Let a molecular graph be defined as \(G = (V,E)\), with \(V\) as the set of nodes and \(E\) as the set of edges \(e = (v,u)\) connecting each pair of neighboring nodes, where \(v \in V\) and \(u \in N(v)\). We denote \(M = \{G_1, \ldots, G_N\}\) as the set of molecular graphs and \(Y = \{y_1, \ldots, y_N\}\) as the set of molecular property labels. The objective is to predict a molecular property \(y_i\) by learning a nonlinear function \(f\) that maps a molecular graph \(G\) to an embedding \(h_{G}\),

$$\begin{aligned} f: G \mapsto h_{G} \end{aligned}$$
(2)

where \(h_G\) is the graph-level embedding used to assist the prediction.

Recent research in this field suggests two main groups of GNNs based on the neighborhood aggregation function for graph embedding. Spectral GNNs decompose each graph to approximate the spectral filters of GNN message-passing layers. On the other hand, spatial GNNs compute the neighborhood aggregation from the node spatial relations between neighboring nodes and edges in the graph [42].

Spatial-based GNNs operate with two arbitrary differentiable functions: a neighborhood aggregation function AGGREGATE and a COMBINE step to merge and update node and edge features.

GNNs are iterative message-passing networks. Multiple message-passing iterations \(l\) update nodes \(h_v^l\) using prior representations of a node \(h_v^{l-1} \) and representations of its close neighbors in the graph \(h_u^{l-1} \). During a message-passing iteration, node embeddings \(h_v\) with \(v \in V\) are updated using representations of neighboring nodes \(u \in N(v)\) and edges \(e = (v,u)\) [43].

Fig. 2

Graphical depiction of the proposed neural network architecture

In this work, we apply a spatial-based GIN [33] as the first module to compute graph embeddings for further learning and prediction. In this case, the GNN implements both COMBINE and AGGREGATE functions as the sum of node and edge features [27]. Thus, node embeddings are updated for each message-passing iteration \(l\) by

$$\begin{aligned} m_{N(v)}^{l} = \text {AGGREGATE}^{l}\left( \{h_u^{l-1}: u \in N(v)\}, \{ h_e^{l-1}: e = (v,u) \}\right) \end{aligned}$$
(3)
$$\begin{aligned} h_{v}^{l} = \sigma \left(\text {MLP}^{l} \left(\text {COMBINE}^{l} \left(h_{v}^{l-1}, m_{N(v)}^{l}\right)\right)\right) \end{aligned}$$
(4)

where \(m\) is the “neural message” passed through the network, \(h_u^l\) are the embeddings of neighboring nodes, and \(h_e^l\) is the feature vector of an edge between nodes \(u\) and \(v\). An UPDATE step includes a multi-layer perceptron \(MLP\) to introduce nonlinearity and a nonlinear activation function \(\sigma \) (ReLU). More specifically, the GNN updates node representations by

$$\begin{aligned} h_{v}^{l} = \text {ReLU}\left( \text {MLP}^{l}\left( \sum _{u\in N(v) \cup \{v\}} h_{u}^{l-1} + \sum _{e = (v,u):\, u\in N(v) \cup \{v\}} h_{e}^{l-1}\right) \right) . \end{aligned}$$
(5)

The idea is that a message is generated from the information about neighboring nodes and combined with the previous embedded representation of a node \(v\), \(h_{v}^{l-1}\), to obtain the updated embedding \(h_{v}^{l}\). The original inputs for aggregation are the initial node and edge attributes \(h_v^0\) and \(h_e^0\). After \(l\) iterations, the final embedding, \(h_v^l, \forall v \in V \), incorporates information about the node (atom) and the contextual subgraph of nodes and edges (bonds between atoms) in an \(l\)-hop neighborhood [44].

Finally, a READOUT mean-pooling operation is performed to obtain an embedding \(h_{G}\). This graph-level representation is obtained by averaging the node embeddings \(h_v\) at the final message-passing layer \(l\)

$$\begin{aligned} h_{G} = \text {mean}(\{h_{v}^{l} : v \in V\}). \end{aligned}$$
(6)

These graph embeddings can be saved for further learning, as we will present in the next section.
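As an illustration of Eqs. (3)–(6), the following sketch implements one sum-aggregation message-passing layer and the mean-pooling readout in plain PyTorch, assuming an adjacency-list graph representation with precomputed node and edge feature vectors; it is a simplified stand-in for the pre-trained GIN of Hu et al. [27] used in this work, not the exact implementation.

```python
import torch
import torch.nn as nn

class GINLayer(nn.Module):
    """One message-passing iteration in the spirit of Eq. (5): sum node and edge
    features over the 1-hop neighbourhood, then apply an MLP and a ReLU activation."""
    def __init__(self, dim=300):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, 2 * dim), nn.ReLU(),
                                 nn.Linear(2 * dim, dim))

    def forward(self, h, edge_feat, neighbors):
        # h: (num_nodes, dim) node embeddings from the previous iteration
        # edge_feat[(v, u)]: (dim,) feature vector of the edge between v and u
        # neighbors[v]: list of nodes u adjacent to v
        out = []
        for v in range(h.size(0)):
            msg = h[v] + sum(h[u] for u in neighbors[v])               # AGGREGATE node features
            msg = msg + sum(edge_feat[(v, u)] for u in neighbors[v])   # AGGREGATE edge features
            out.append(torch.relu(self.mlp(msg)))                      # COMBINE and UPDATE
        return torch.stack(out)

def readout(h):
    """Eq. (6): mean-pool the final node embeddings into a graph-level embedding h_G."""
    return h.mean(dim=0)
```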

A graphical depiction of the proposed model architecture is shown in Fig. 2. For graph operations, the nodes being operated on are displayed in blue, with neighboring nodes shown in black. For AGGREGATE, COMBINE and UPDATE, the operations are shown for a single node \(v \in V \) and performed on all nodes \(v\) in the graph, simultaneously. In this case, we consider graph operations for \(L = 5\) message-passing layers and the READOUT operation is performed at the final layer. In the convolution operations, \(h_{\rm{conv}}\) describes the deep representations extracted by a CNN \(g\) from graph embeddings of \({\text{size}} = 300\). Different blue squares denote different values of the node-level and graph-level embeddings \(h_v\) and \(h_G\), respectively. Different orange squares denote different values of deep representations \(h_{\rm{conv}}\).

Node and edge features are described by atom and bond attributes. Node attributes include the atomic number (AN) and atom chirality (AC) to describe the type of atom, how it is connected, and its spatial interaction behavior with neighboring nodes. Edge attributes include bond type (BT) and bond direction (BD) to specify the structural aspects of chemical bonds and the spatial orientation of the edges. Formally, the input node and edge features are described as

$$\begin{aligned} h_{v}^{0} = (v_{\text {AN}}, v_{\text {AC}}) \end{aligned}$$
(7)
$$\begin{aligned} h_{e}^{0} = (e_{\text {BT}}, e_{\text {BD}}) \end{aligned}$$
(8)

with \((\cdot ,\cdot ) \) denoting a concatenation operation and \(e\) and \(v\) as edge and node attributes, respectively.

Pre-trained GIN, GCN and other graph-based architectures have been widely used in drug discovery applications. In this sense, the GIN model is pre-trained using a recent pre-training technique [27] to achieve better parameter initialization and learn global and more generic descriptors.

3.2 Prediction module: convolutional neural network

CNNs inherit many of the properties of artificial neural networks, developing structure in a feature space whose complexity grows across layers. These layers are made up of neurons with a set of learnable weights and biases.

Convolutional layers consist of several convolutional filters (weight matrices) capable of extracting local features from the input. In the forward pass, these representations are propagated across convolutional layers while convolutional filters slide through the spatial dimensions of the input. The output feature maps are the result of the convolution operation between the convolutional filters and different positions in the input vector [29].

These sequential layers emerge as detectors of local patterns by restricting the connections with neurons to small regions of the input. The subsequent increase in the number of convolutional filters and in the combinatorial size of the feature space allows the extraction of more complex and generic descriptors [37].

In this study, one-dimensional CNNs are explored to build a molecular property prediction module. By incorporating the graph structure information enclosed in graph embeddings, CNNs can discriminate important patterns between close and distant neighbors in the graph.

To this end, node representations \(h_v\) are used to calculate graph embeddings \(h_G\) through mean pooling over the nodes, as described in the previous section. Embeddings are then used as input feature vectors for further computation.

First, we collect the embeddings \(H = \{h_{G_1}, \ldots, h_{G_N}\}\) obtained from the original graphs \(G_i\). Then, we perform the nonlinear mapping of \(h_{G}\) to a deep representation vector \(h_{\rm{conv}}\) using a CNN \(g\). This relation can be defined by

$$\begin{aligned} g: h_{G} \mapsto h_{\rm{conv}}. \end{aligned}$$
(9)

Deep representations \(h_{\rm{conv}}\) are then propagated to become increasingly smaller and more complex as we progress to deeper layers.

The prediction module \(g\) presents a conventional architecture with 3 CNN blocks. One-dimensional input embeddings of \({\text{size}} = 300\) are convolved with filters of size \(3 \times 1\), followed by batch normalization with \({\text{momentum}} = 1\), \(\epsilon = 10^{-5}\) and ReLU activation (see Table 1).

Table 1 CNN architecture details

The convolution operation consists of a set of element-wise multiplications followed by a sum. The convolution between the convolutional filters and the elements of the input embeddings \(h_G\), over which these filters slide, returns the feature map of convolutional outputs \(h_{\rm{conv}}\). This process is repeated across layers to return the output of the convolution operation between weight matrices \(W\) and the regions to which they are connected in the input vector (see Fig. 3) [37].

Fig. 3

Schematic of the one-dimensional (1D) convolutional operation
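As a toy illustration of this operation (with made-up values, not taken from the model), a filter of size 3 sliding over a one-dimensional, single-channel input produces one output element per position as the sum of element-wise products:

```python
import torch
import torch.nn.functional as F

# Toy 1D convolution: a filter of size 3 slides over a single-channel input of
# length 5 and produces 5 - 3 + 1 = 3 output elements.
h_G = torch.tensor([[[1.0, 2.0, 0.0, -1.0, 3.0]]])   # (batch=1, channels=1, length=5)
w = torch.tensor([[[0.5, 1.0, -0.5]]])               # (out_channels=1, in_channels=1, kernel=3)
h_conv = F.conv1d(h_G, w)
print(h_conv)   # tensor([[[ 2.5000,  1.5000, -2.5000]]])
```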

After convolution, batch normalization helps to coordinate the updates of multiple layers in the CNN module by scaling the output feature maps of convolutional layers. This is done by normalizing the activations of each input variable per mini-batch, such as the neural activations from previous layers. This process of standardization stabilizes and speeds up model convergence while reducing the generalization error [45].

Then, ReLU (Rectified Linear Unit) units apply a function that converts each element \(x\) of the input vector to a non-negative value, \({\max}(0,x)\). This operation sets a threshold at 0, where negative values are nullified, accelerating convergence due to the linear profile of the ReLU activation [46].

Hence, the feature maps for each \(l\)-th convolutional layer take the form

$$\begin{aligned} h^{l} = \text {max}(0,\text {BatchNorm}(W^{l-1,l}* h^{l-1} + b^l)) \end{aligned}$$
(10)

where \(W^{l-1,l}\) is the weight matrix connecting units of layer \(l-1\) to units in layer \(l\) and \(b^l\) the bias vector for a layer \(l\). \(*\) is the convolution operation to return the output units for neural activation and \(h^{l}\), \(h^{l-1}\) are the hidden vectors in layers \(l\) and \(l-1\), respectively.

Finally, a dense fully connected (FC) layer followed by sigmoid activation uses the output of the last convolutional layer \(l = 3\) to compute the prediction (condensed into a value in \([0,1]\)).
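A minimal PyTorch sketch of such a prediction module is given below; it follows the description above (three Conv1d–BatchNorm–ReLU blocks with kernel size 3 over embeddings of size 300, followed by a fully connected layer with sigmoid activation), while the channel counts are illustrative assumptions since the exact values are listed only in Table 1.

```python
import torch
import torch.nn as nn

class CNNPredictor(nn.Module):
    """Sketch of the prediction module g: three Conv1d-BatchNorm-ReLU blocks followed
    by a fully connected layer with sigmoid activation. Channel counts are assumptions."""
    def __init__(self, emb_dim=300, channels=(16, 32, 64)):
        super().__init__()
        blocks, in_ch = [], 1
        for out_ch in channels:
            blocks += [nn.Conv1d(in_ch, out_ch, kernel_size=3),
                       nn.BatchNorm1d(out_ch, momentum=1.0, eps=1e-5),
                       nn.ReLU()]
            in_ch = out_ch
        self.conv = nn.Sequential(*blocks)
        # each valid convolution with kernel size 3 shortens the sequence by 2
        self.fc = nn.Linear(channels[-1] * (emb_dim - 2 * len(channels)), 1)

    def forward(self, h_G):
        # h_G: (batch, emb_dim) graph embeddings treated as single-channel 1D signals
        h_conv = self.conv(h_G.unsqueeze(1))               # (batch, 64, emb_dim - 6)
        return torch.sigmoid(self.fc(h_conv.flatten(1)))   # property probability in [0, 1]

# Example: scores for a batch of 8 graph embeddings of size 300
scores = CNNPredictor()(torch.randn(8, 300))
```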

Fig. 4

Graphical representation of the CNN module architecture

The CNN module treats graph-level embeddings as images. Each embedding is viewed as a grid and elements in the same receptive field are convoluted to build a representation that accounts for the dependencies between neighboring nodes and edges. Local connections between close and distant regions of graph embeddings are then explored to compute increasingly smaller and complex representations.

The convolution of graph embeddings \(h_G\) transforms molecular graphs into deep vectorial representations \(h_{\rm{conv}}\). By propagating such representations across convolutional layers, CNNs model complex patterns and local dependencies within molecular structures as a function for molecular property prediction.

A detailed representation of the CNN prediction module is depicted in Fig. 4.

4 Training and inference

In this section, we introduce a few-shot meta-learning framework on the basis of model-agnostic meta-learning (MAML) by Finn et al. [47] to predict new molecular properties across tasks.

Meta-learning relies on prior knowledge to systematically revisit previous learning episodes and define a promising strategy based on experience. Fewer examples are required as the number of learning tasks increases, making the process faster and more efficient. This type of non-trivial learning adapts and generalizes to new representations from limited available data [48, 49].

From this perspective, we address the challenge of low data by optimizing the proposed model across several learning tasks. In meta-training, the model is trained on a labeled support set and evaluated on a disjoint query set for each task. Previous learning episodes are then used to predict new tasks and optimize the algorithm to the new task at hand. This process is repeated across tasks to return a predictor that generalizes well to unseen representations in few-shot data.

The goal is to predict a molecular property (e.g., toxicity, side effects) of a query molecule \(x\) so that \(g_{\theta ^*}(f_{\theta }(x)): M \rightarrow \{0,1\} \subseteq Y\), where \(M\) is the space of all molecular graphs \(G\), \(f_{\theta }\) is a GNN producing the embedding \(h(x)\), \(g_{\theta ^*}\) is a CNN and \(Y\) is the set of molecular property labels.

In this study, we train two meta-models, a GNN \(f_{\theta }\) and a CNN \(g_{\theta ^*}\) with parameters \(\theta \) and \(\theta ^*\), across tasks \(t\) from a distribution \(\rho (T) \). Meta-training and meta-testing sets are sampled for each molecular property task \(t\). Both include a support set \(S\) for training and a query set \(Q\) for evaluation, so each task \(t\) is parameterized by a task-specific support set \(S_t\) and query set \(Q_t\).

Particularly, under a \(k\)-shot meta-training, both models \(f_\theta \) and \(g_{\theta ^*}\) with parameters \(\theta \) and \(\theta ^*\), are constantly updated, trained on \(S_t\) and evaluated on \(Q_t\) for each task.

First, for each \(k\)-shot task \(t\), \(k\) support samples \(G_{S_{t_i}}\) are randomly sampled and fed into the GNN-CNN two-module architecture to compute the support losses \(\mathcal {L}^{\rm{gnn}}_{t}\) and \(\mathcal {L}^{\rm{conv}}_{t}\).

Subsequently, the support losses are used to update the model parameters \(\theta \rightarrow \theta '\), \(\theta ^{*} \rightarrow \theta ^{*'}\), and both models are evaluated on a query set \(Q_t\) to compute the query losses \(\mathcal {L}^{{\rm{gnn}}'}_{t}\) and \(\mathcal {L}^{{\rm{conv}}'}_{t}\) using the remaining \(n\) samples for each task.

Fig. 5

Schematic of the proposed meta-learning framework for few-shot molecular property prediction

In practice, we update the model parameters to adapt to a new task \(t\). Tasks include \(k\) support samples to compute the support losses for both modules, \(\mathcal {L}^{\rm{gnn}}_{t}(\theta )\) and \(\mathcal {L}^{\rm{conv}}_{t}(\theta ^*)\). In this process, we apply a few gradient steps

$$\begin{aligned} \theta _{t} = \theta - \alpha \nabla _\theta \mathcal {L}^{\rm{gnn}}_{t}(\theta ) \end{aligned}$$
(11)
$$\begin{aligned} \theta ^*_{t} = \theta ^* - \alpha ^* \nabla _{\theta ^*} \mathcal {L}^{\rm{conv}}_{t}(\theta ^*) \end{aligned}$$
(12)

with \(\alpha \) and \(\alpha ^*\) as the step sizes for the gradient descent updates.

In meta-testing, we randomly sample a support set of size \(k\) for a new task \(t\) to optimize the model parameters \(\theta \rightarrow \theta '\), \(\theta ^* \rightarrow \theta ^{*'}\) through a few gradient descent steps. Molecular properties are finally predicted in a disjoint query set with the remaining samples. Thus, the updated model parameters are used to generalize to new samples and compute deep representations \(h_{\rm{conv}}\) to assist the prediction of molecular properties.

In this framework, it is expected to obtain optimized parameters that map the molecular graphs to different task-specific properties. The goal is to generalize well to new tasks in the test data after a few gradient steps. A schematic of the proposed meta-learning framework is shown in Fig. 5.

When training the model on a specific data collection, tasks or assays are divided into training and testing tasks. As stated previously, training consists of a set of learning episodes. For each episode, a task is randomly sampled from the training tasks, and a support set of size \(n_{+}\) + \(n_{-}\) (with \(n_{+}\) positive and \(n_{-}\) negative samples) and a batch of queries are randomly sampled from that task. Each learning episode takes a few gradient descent steps to minimize the loss function using the Adam optimizer.
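The following sketch outlines one such learning episode with a first-order MAML-style update, assuming the GNN and CNN modules are bundled into a single `model`; the helper arguments (`loss_fn`, `support`, `query`, `meta_opt`) and the first-order simplification are illustrative assumptions rather than the exact training code.

```python
import copy
import torch
from torch.optim import SGD

def meta_train_episode(model, loss_fn, support, query, meta_opt,
                       inner_lr=0.01, inner_steps=5):
    """One learning episode (first-order MAML-style approximation). `model` bundles the
    GNN and CNN modules; `support` and `query` are (inputs, labels) pairs of one task."""
    # Inner loop (Eqs. 11-12): adapt a copy of the parameters on the support set.
    fast = copy.deepcopy(model)
    inner_opt = SGD(fast.parameters(), lr=inner_lr)
    for _ in range(inner_steps):
        x_s, y_s = support
        loss = loss_fn(fast(x_s), y_s)
        inner_opt.zero_grad()
        loss.backward()
        inner_opt.step()

    # Outer loop: evaluate the adapted parameters on the query set and apply the
    # resulting gradients to the original (meta) parameters.
    x_q, y_q = query
    query_loss = loss_fn(fast(x_q), y_q)
    grads = torch.autograd.grad(query_loss, fast.parameters())
    meta_opt.zero_grad()
    for p, g in zip(model.parameters(), grads):
        p.grad = g          # first-order approximation of the meta-gradient
    meta_opt.step()
    return query_loss.item()
```

In each episode, a training task and its support/query sets are drawn as described above before calling this routine with the Adam meta-optimizer.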

The accuracy of the model is evaluated separately for each test task at test time. For each test task, a support set of size \(n_{+}\) + \(n_{-}\) is randomly sampled from the data for that task. ROC-AUC scores are then evaluated on the remaining data for that task.

In the results Sect. 5.3, we use the notation \((5+,5-)\) and \((10+,10-)\) to represent \(n_{+} = 5\), \(n_{-} = 5\) and \(n_{+} = 10\), \(n_{-} = 10\), respectively. Appendix Sect. 7 provides further details about the tasks considered for each data collection.

4.1 Cost-sensitive loss for imbalanced classification

The loss for both modules \(\mathcal {L}^{\rm{gnn}}\) and \(\mathcal {L}^{\rm{conv}}\) is the binary cross-entropy loss over the predicted properties \(y'\) and the molecular property ground-truth labels \(y\), with \(k\) as the number of samples,

$$\begin{aligned} \mathcal {L} = - \frac{1}{k} \sum _{i=1}^{k} \left[ y_i \, \log (y_i') + (1-y_i) \, \log (1-y_i') \right]. \end{aligned}$$
(13)

However, the problem of class imbalance in few-shot data prevents us from obtaining superior performance on either the Tox21 or SIDER benchmarks. To address this issue, we introduce a customized version of the binary cross-entropy loss that assigns a weight to the minority class, yielding a weighted version of the original objective. This customized loss function takes into account the distribution of each class to penalize failed predictions for rare instances, which greatly impact the loss value.

The weighted version of binary cross-entropy defines a weight \(p\) for the minority class,

$$\begin{aligned} \mathcal {L} = - \frac{1}{k} \sum _{i=1}^{k} \left[ p \, y_i \, \log (y_i') + (1-y_i) \, \log (1-y_i') \right] \end{aligned}$$
(14)

with \(p\) defined as the ratio of negative to positive samples. For instance, if a dataset contains 100 positive and 500 negative examples of a single class, then \(p\) for that class should be equal to \(\frac{500}{100}=5\). Since different tasks present different positive/negative distributions, we determine \(p\) by exploring multiple values between 1 and 50 and selecting those that return superior performance (\(p = 35\) for Tox21 and \(p=1\) for SIDER due to task variability).
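A direct sketch of Eq. (14) is shown below; an equivalent effect can also be obtained through the `pos_weight` argument of PyTorch's `BCEWithLogitsLoss` when the module outputs logits instead of probabilities.

```python
import torch

def weighted_bce(y_pred, y_true, p):
    """Eq. (14): binary cross-entropy with a weight p on the positive (minority) class.
    y_pred holds predicted probabilities in (0, 1); y_true holds 0/1 labels."""
    eps = 1e-7
    y_pred = y_pred.clamp(eps, 1 - eps)
    loss = -(p * y_true * torch.log(y_pred) + (1 - y_true) * torch.log(1 - y_pred))
    return loss.mean()

# When the module outputs logits, the same weighting is available through pos_weight,
# e.g. the Tox21 setting used here:
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([35.0]))
```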

5 Experiments

The present section reports multiple experiments on two benchmark datasets (Tox21 and SIDER).

The Tox21 dataset comprises qualitative toxicity measurements for 7831 compounds for 12 biological targets including nuclear receptors (NR) and stress response pathways (SR). Each sample represents a compound with 12 binary labels for 12 toxicology experiments [50].

Tox21 is a machine-learning challenge formerly won by a multitask learning approach across deep networks. The main goal is to predict the toxicity of small molecules for a specific NR or SR. In a few-shot learning setting, we use a different proportion of training and testing tasks, disregarding the original train-test split. For a total of 12 tasks, the data were split into 9 tasks for training and 3 for testing (see Table 2).

Table 2 Tox21 comprises qualitative toxicity measurements related to 12 biological targets

SIDER is a collection of 1427 well-validated drugs and adverse drug reactions (ADRs) grouped into 27 system organ classes [51]. SIDER data are extracted from several public articles and publications containing labeled information on marketed drugs, including side effect frequencies, drug-target interactions and drug/side effect classification. The goal is to predict whether a compound triggers a side effect for 27 organ systems. For a total of 27 tasks, the data were split into 21 tasks for training and 6 for testing (see Table 3).

Table 3 SIDER includes a database of marketed medicines grouped into 27 different system organ classes

Both datasets are broadly distinct and represent diverse collections of molecular scaffolds. Thus, it is expected that models perform differently on both data collections.

Raw data of molecules are given in the form of SMILES strings and converted into molecular graphs using the RDKit Python library (rdkit.Chem) [52]. These SMILES strings are converted into node and edge features that best describe the structure and spatial arrangement of the molecules used in the experiments.
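A minimal sketch of this conversion, extracting the atom and bond attributes of Eqs. (7) and (8) with RDKit, could look as follows; the integer encoding of the attributes is an illustrative simplification of the preprocessing actually used.

```python
from rdkit import Chem

def smiles_to_graph(smiles):
    """Convert a SMILES string into node and edge attribute lists (Eqs. 7-8):
    atomic number and chirality per atom, bond type and bond direction per bond."""
    mol = Chem.MolFromSmiles(smiles)
    nodes = [(atom.GetAtomicNum(), int(atom.GetChiralTag()))
             for atom in mol.GetAtoms()]
    edges, edge_attrs = [], []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        attr = (int(bond.GetBondType()), int(bond.GetBondDir()))
        edges += [(i, j), (j, i)]          # undirected graph: keep both directions
        edge_attrs += [attr, attr]
    return nodes, edges, edge_attrs

# Example: caffeine
nodes, edges, edge_attrs = smiles_to_graph("Cn1cnc2c1c(=O)n(C)c(=O)n2C")
```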

5.1 Baselines

The proposed FS-GNNConv model is compared with three graph-based models:

  (i) GIN: pre-trained version of GIN.

  (ii) GCN: pre-trained GCN model. The GNN includes a convolutional component for node aggregation. Nodes are seen as pixels, and neighbors in the same receptive field are used to compute node embeddings as the output of the convolution [53].

  (iii) GraphSAGE: pre-trained GraphSAGE model. Graph-based network that samples and aggregates neighboring embeddings to leverage relevant graph features. This is an inductive framework that exploits these node attributes to efficiently generate representations on previously unseen data [30].

All baselines were pre-trained with the GCN, GIN and GraphSAGE models of Hu et al. [27] to improve performance. A meta-learning framework was applied to all baselines to achieve comparable results.

5.2 Evaluation metrics

Binary classification of molecular properties is evaluated by ROC-AUC scores on the query set of each test task. For a given test task, we randomly sample a support set with \(k\) examples to collect the data points for that task. Then, we evaluate the ROC-AUC scores for the model on the remainder of the data points for each test task in a disjoint query set.

In model evaluation, each task is considered independent, and we report the results for 2-way binary classification with 5 shots and 10 shots. To obtain more robust results, this procedure is repeated 20 times for each test task using 20 randomly sampled support sets to calculate the average ROC-AUC scores. The notation \((n_{+}, n_{-})\) indicates random support sets with \(n_{+}\) positive samples and \(n_{-}\) negative samples.
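The evaluation loop can be sketched as follows; `split_support_query` and `adapt_and_predict` are hypothetical helpers standing in for the support/query sampling and the fine-tune-then-score step described above.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def evaluate_task(task_data, adapt_and_predict, n_pos=5, n_neg=5, repeats=20, seed=0):
    """Average ROC-AUC over `repeats` random support sets for one test task.
    `adapt_and_predict` and `split_support_query` are hypothetical helpers."""
    rng = np.random.default_rng(seed)
    scores = []
    for _ in range(repeats):
        # draw (n_pos, n_neg) support molecules; the rest of the task forms the query set
        support, query_x, query_y = split_support_query(task_data, n_pos, n_neg, rng)
        y_score = adapt_and_predict(support, query_x)
        scores.append(roc_auc_score(query_y, y_score))
    return np.mean(scores), np.std(scores)
```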

Experimental results including the mean and standard deviation of ROC-AUC scores for \((5+,5-)\) and \((10+,10-)\) random support sets are displayed in Tables 4 and 5.

5.3 Results

In this work, we systematically address the low-data problem in molecular property prediction by introducing a two-module architecture, FS-GNNConv, to effectively learn deep representations from graph embeddings. In addition, we demonstrate that the proposed model outperforms different graph-based baselines.

Note that the work presented undertakes standard meta-learning practices (MAML) to iteratively adapt and generalize to new experimental tasks. Here, few-shot learning experiments model the behavior of small molecules in new experimental tasks given just a few samples of these new systems.

This section reports experimental results for few-shot models across a number of tasks on the Tox21 and SIDER datasets.

The results in Tables 4 and 5 are the average ROC-AUC scores obtained on 20 experiments with 20 different \((5+,5-)\) and \((10+,10-)\) random support sets, respectively.

Table 4 Average ROC-AUC scores for binary classification with 5-shots on benchmark datasets Tox21 and SIDER
Table 5 Average ROC-AUC scores for binary classification with 10-shots on benchmark datasets Tox21 and SIDER

The results in Table 4 confirm that the proposed model outperforms the best baseline method on Tox21 for all 3 test tasks and for 4 test tasks on SIDER. For 5-shot experiments, we observe an average overall improvement on Tox21 of \(+11.01\%\).

Table 5 reports analogous results for the 10-shot experiment. In this case, the proposed model outperforms the best baseline method on Tox21 for all 3 test tasks and for 4 test tasks on SIDER. We also observe an average overall improvement on Tox21 of \(+11.37\%\) and \(+0.53\%\) on SIDER.

A graphical representation of these results is shown in Figs. 6 and 7. For both datasets, SIDER and Tox21, there are clear differences in performance between tasks, suggesting that the model generalizes better to some tasks than others in the test data. Due to a smaller number of tasks and a greater number of samples per task, the model performs better on Tox21, showing lower variances and higher ROC-AUC scores.

For the 5-shot and 10-shot experiments, the standard deviations indicate the lower variance of the proposed model compared with the graph-based baselines. This translates into a more stable performance that provides more robust results.

It is clear that the baseline methods do not present a stable performance on a number of tasks. Simply put, they may generalize well for one task but perform poorly for most tasks.

Fig. 6

Distribution of ROC-AUC scores of the proposed model for 20 experiments with 20 random \((5+,5-)\) support sets

Fig. 7

Distribution of ROC-AUC scores of the proposed model for 20 experiments with 20 random \((10+,10-)\) support sets

In this scenario, most few-shot models struggle to deal with larger support set sizes and show a more robust improvement in the presence of less data. On that account, we report superior performance with \((5+,5-)\) random support sets for most baseline methods. However, the same does not apply to the proposed model. This can be explained by the convolutional component, which adds a significant boost with larger support sets \((10+,10-)\).

Since SIDER has more tasks than Tox21, it is difficult to achieve great overall performance for an extensive set of separate tasks. In contrast, due to the larger size of Tox21, the proposed model performs better by exploiting the generalization capabilities of the CNN module.

Nonetheless, we still experience some underlying limitations of few-shot models. As shown in transfer learning experiments presented in Sect. 5.4, few-shot models struggle to classify completely different tasks with little or no degree of similarity between them. Experimentation also demonstrates that few-shot methods find it difficult to generalize to unrelated tasks regardless of the direction from which we transfer the knowledge. These results indicate that there is a long path to achieve broader generalization and predict completely unrelated tasks of a disjoint system.

Finally, in Sect. 5.5, we explore t-SNE visualizations to visually compare graph embeddings and deep representations. For each dataset, we show the differences between t-SNE cluster plots obtained by the proposed model and the graph-based baselines. The plots show that deep representations mapped to the reduced space discriminate better between both types of molecules (positive or negative) for each molecular property.

All documentation and code scripts to reproduce the results are available to facilitate further experimentation.

5.4 Case study: transfer learning with few-shot models

Previous experiments report the ability of few-shot learning to transfer information from one training task to rather similar testing tasks. To complement this work, we also test whether the proposed model is able to transfer knowledge from Tox21 to predict new tasks in the SIDER benchmark, and vice versa. In practice, the goal is to learn a model trained to predict the toxicity on different nuclear receptors (NR) and stress response pathways (SR) (Tox21) and use it to predict the side effects on real patients over 27 organ systems (SIDER). Conversely, we aim to predict the toxicity on Tox21 from a model trained to predict side effects on SIDER.

Consistent experimentation is conducted to evaluate whether few-shot models are able to generalize to unrelated tasks when provided with very little or no supervised information similar or closely related to the test data.

From this perspective, we assess the ability of few-shot learning to generalize by transferring knowledge between two broadly distinct data repositories. In Table 6, we report the mean ROC-AUC scores for all 27 SIDER tasks for models trained on Tox21. This experiment is repeated 20 times with 20 different \((5+,5-)\) random support sets. The reverse experiment with the same settings is reported in Table 7.

Table 6 Mean ROC-AUC scores of models trained on Tox21 to predict SIDER tasks
Table 7 Mean ROC-AUC scores of models trained on SIDER to predict Tox21 tasks

We conclude that none of the few-shot models reported achieve acceptable performance for rather distinct data collections, attesting to the lack of predictive power for unrelated tasks.

5.5 Case study: t-SNE visualization of graph embeddings and deep representations

Well-understood methods such as principal component analysis (PCA) map high-dimensional data into low-dimensional feature spaces by retaining the global structure, preserving data variance across the entire dataset.

t-distributed stochastic neighborhood embedding (t-SNE) works differently by focusing on closely located datapoints. To this end, t-SNE computes a metric to measure the distance between datapoints and a given number of neighbors and models this relation with a Student's t-distribution. Then, it tries to find an optimal embedding such that graphs in the original \(n\)-dimensional space are mapped to close locations in a low-dimensional space [54].

t-SNE works remarkably well to retain the local structure so that clusters of graph embeddings and deep representations in the reduced space are interpretable as molecules that were also similar in the high-dimensional space.

t-SNE visualizations of deep representations for different molecular property tasks on the Tox21 dataset (SR-HSE, SR-MMP and SR-p53) are shown in Figs. 8 and 9. In Fig. 10, we compare t-SNE visualizations of graph embeddings and deep representations. Blue dots denote negative samples for each test task. The orange dots represent positive samples.

Fig. 8

t-SNE visualizations of deep representations \(h_{\rm{conv}}\) generated by FS-GNNConv for the Tox21 dataset for \((5+,5-)\) random support sets. The orange dots represent positive labels and the blue points the negative labels. SR-HSE, SR-MMP and SR-p53 tasks are described by plots (a), (b) and (c), respectively

Fig. 9

t-SNE visualizations of deep representations \(h_{\rm{conv}}\) generated by FS-GNNConv for the Tox21 dataset for \((10+,10-)\) random support sets. The orange dots represent positive labels and the blue points the negative labels. SR-HSE, SR-MMP and SR-p53 tasks are described by plots (a), (b) and (c), respectively

Fig. 10

t-SNE visualizations of graph embeddings generated by GraphSAGE, GIN, GCN and deep representations for the Tox21 SR-MMP task for \((5+,5-)\) random support sets. The orange dots represent positive labels and the blue points the negative labels

One important feature of t-SNE is the perplexity parameter. Perplexity balances the importance of local and global structure in the plotted result. Lower values emphasize local aspects, while larger values give a more global sense of the geometry of the high-dimensional space. To balance both views, we fine-tuned the perplexity parameter and fixed it at a value of 30.
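A minimal example of producing such a projection with scikit-learn is sketched below, using random stand-in data in place of the learned representations.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Random stand-in data for the learned representations (n_molecules x 300) and labels.
h_conv = np.random.randn(200, 300)
labels = np.random.randint(0, 2, size=200)   # 0 = negative, 1 = positive

coords = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(h_conv)
plt.scatter(coords[labels == 0, 0], coords[labels == 0, 1], c="tab:blue", label="negative")
plt.scatter(coords[labels == 1, 0], coords[labels == 1, 1], c="tab:orange", label="positive")
plt.legend()
plt.show()
```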

In Figs. 8 and 9, it is noticeable that our model performs well in discriminating both types of molecules (positive or negative) since positive datapoints are found closer to each other in the reduced space. For cases (a) and (c), we can see clusters of orange dots progressively separating from blue datapoints. In case (b), most orange points are separated from the blue datapoints located in the upper region, denoting two well-defined clusters of molecules.

It is clear that our model achieves better performance when discriminating between positive and negative samples than the other baseline methods. In Fig. 10, GIN, GCN and GraphSAGE show sparsely located positive and negative samples, making it difficult to identify well-defined groups of molecules. Conversely, deep representations obtained by FS-GNNConv place positive samples on the bottom left corner to form clusters of molecules closely related to each other.

In addition, we observe local dependencies separating both positive and negative samples. In a broader view, an elongated shape is visible, which might indicate the existence of complex interaction patterns shared by deep representations expressing a global connectivity among molecules.

6 Conclusion

The main goal of this paper is to tackle the challenge of low data in few-shot molecular property prediction. We systematically address this issue by introducing an architecture that effectively learns deep representations from graph embeddings. In this work, we demonstrate that the proposed model outperforms different graph-based methods.

The small datasets considered (Tox21 and SIDER) simulate an environment favorable for low-data learning, where few-shot models unequivocally outperform simple deep learning approaches. Both benchmarks include high-level measurements of toxicity and side-effect frequency, making predictions volatile and highly uncertain. This behavior makes few-shot learning results particularly interesting and gives a strong indication of superior performance on small biological datasets.

In this work, we proposed a new few-shot two-module architecture, called FS-GNNConv, to address the low-data problem of molecular property prediction. A GNN module encodes the topological structure of molecular graphs as a set of node (atoms) and edge features (chemical bonds). The resulting graphs are then converted into embedded representations. By exploiting the rich information of these embedded descriptors, a CNN propagates deep representations across convolutional layers to generalize to new chemical properties and unseen classes of molecular scaffolds.

A meta-learning framework for optimizing the two-module network across tasks was developed, promoting quick adaptation to new molecular properties on few-shot data. Analysis of the experimental results demonstrated the predictive power and robustness of the proposed model over standard graph-based methods on multi-property prediction data. The results showed that FS-GNNConv takes a step forward in generalizing to new experimental tasks that are similar, or only marginally different, from the tasks found in training.

As shown in Sect. 5.3, the proposed model outperforms the best baseline method presented for the majority of test tasks with an average overall improvement of \(+11.37\%\) and \(+0.53\%\) for Tox21 and SIDER, respectively (for \((10+,10-)\) random support sets). We posit that the novel proposed framework fully explores the potential of graph-level embeddings to generalize to new molecular properties in contrast with the other GNN competitors.

Future work includes the exploration of few-shot models that generalize to unrelated drug discovery tasks with no degree of structural similarity among molecules. It would also be a promising direction to extend these ideas to regression tasks, broadening predictions to a larger spectrum of molecular properties. We believe this study demonstrates that starting with few-shot models as powerful non-trivial predictors can help to improve broader generalization in the molecular property prediction problem.