---

The library is designed with 3 things in mind:

- High modularity and configurability with *YAML* files
- State-of-the-art GNNs, including positional encodings and graph Transformers
- Massive multi-tasking across diverse and sparse datasets

This page walks you through the different aspects of the design that enable these goals.

### Diagram for data processing in Graphium

First, when working with molecules, there are many options for atomic and bond featurisation, which can be extracted from the periodic table, from empirical results, or from simulated 3D structures.

Second, when working with graph Transformers, there are many options for the positional and structural encodings (PSE), which are fundamental in driving the accuracy and the generalization of the models.

With this in mind, we propose a very versatile chemical and PSE encoding, alongside an encoder manager, that can be fully configured from the YAML files. The idea is to assign matching *input keys* to both the features and the encoders, then pool the outputs according to the *output keys*. This is summarized in the image below, with a configuration sketch following it.

<img src="images/datamodule.png" alt="Data Processing Chart" width="100%" height="100%">
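
To make the key matching concrete, here is a minimal, hypothetical YAML sketch of how features and encoders could be wired together. The key names, encoder types, and fields below are illustrative assumptions rather than the exact Graphium schema; the config documentation linked later on this page describes the real options.

```yaml
# Hypothetical sketch -- key names, encoder types and fields are
# illustrative, not the exact Graphium schema.
architecture:
  pe_encoders:
    out_dim: 32
    pool: sum                        # encoders sharing an output key are pooled together
    encoders:
      lap_pe:
        encoder_type: laplacian_pe   # consumes the Laplacian eigen-features
        input_keys: [laplacian_eigvec, laplacian_eigval]
        output_keys: [feat]          # pooled into the node features
      rw_pe:
        encoder_type: rw_return_probs
        input_keys: [rw_return_probs]
        output_keys: [feat]

datamodule:
  featurization:                     # produces the matching *input keys* above
    atom_property_list_onehot: [atomic-number, degree]
    edge_property_list: [bond-type-onehot]
    pos_encoding_as_features:
      lap_pe: {num_pos: 8}
      rw_pe: {ksteps: 16}
```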

### Diagram for multi-task network in Graphium

As mentioned, we want to perform massive multi-tasking so that we can work across a huge diversity of datasets. The idea is to use a combination of multiple task heads, where a different MLP is applied to each task's predictions. Each task can also have as many labels as desired, which makes it possible to group labels according to whether they should share weights and losses.

The design is illustrated in the image below, with a configuration sketch following it. Notice how the *keys* output by `GraphDict` are used differently across the different GNN layers.

<img src="images/full_graph_network.png" alt="Full Graph Multi-task Network" width="100%" height="100%">
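
As a rough illustration of the multi-task design, a configuration could declare one head per task, each with its own MLP and loss. This is a hypothetical sketch: the task names, dimensions, and loss names are made up for the example.

```yaml
# Hypothetical sketch -- task names, dimensions and losses are illustrative.
architecture:
  task_heads:
    homo_lumo:             # one MLP head per task
      out_dim: 1
      hidden_dims: 256
      depth: 2
    tox21:
      out_dim: 12          # 12 binary labels grouped under a single head,
      hidden_dims: 128     # sharing weights and a loss
      depth: 2

predictor:
  loss_fun:
    homo_lumo: mae         # each head can use a different loss
    tox21: bce
```

Because the heads only branch at the end of the network, all tasks share the same GNN trunk while keeping independent output layers and losses.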

## Structure of the code

The code is built to rapidly iterate on different architectures of neural networks (NN) and graph neural networks (GNN) with PyTorch. The main focus of this work is molecular tasks, and we use the package `rdkit` to transform molecular SMILES into graphs.

Below is a list of the directories and their respective documentation:

- cli
- [config](https://github.com/datamol-io/graphium/blob/main/graphium/config/README.md)
- [data](https://github.com/datamol-io/graphium/blob/main/graphium/data/README.md)
- [features](https://github.com/datamol-io/graphium/tree/main/graphium/features/README.md)
- finetuning
- [ipu](https://github.com/datamol-io/graphium/tree/main/graphium/ipu/README.md)
- [nn](https://github.com/datamol-io/graphium/tree/main/graphium/nn/README.md)
- [trainer](https://github.com/datamol-io/graphium/tree/main/graphium/trainer/README.md)
- [utils](https://github.com/datamol-io/graphium/tree/main/graphium/utils/README.md)
- [visualization](https://github.com/datamol-io/graphium/tree/main/graphium/visualization/README.md)

## Structure of the configs

Making the library this modular requires configuration files of more than 200 lines, which becomes intractable, especially when we only want minor changes between configurations.

Hence, we use [hydra](https://hydra.cc/docs/intro/) to split the configurations into smaller, composable configuration files.

Examples of possibilities include (see the sketch after this list):

- Switching between accelerators (CPU, GPU and IPU)
- Benchmarking different models on the same dataset
- Fine-tuning a pre-trained model on a new dataset
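
For intuition, a hydra-style main configuration might compose smaller files from config groups as below. The group and option names are hypothetical; the linked document describes the actual layout used in Graphium.

```yaml
# Hypothetical sketch -- config-group and option names are illustrative.
defaults:
  - accelerator: ipu    # swap in cpu or gpu without touching anything else
  - model: gcn          # benchmark another model on the same dataset
  - dataset: toymix
  - training: default
  - _self_

constants:
  seed: 42
```

With such a layout, any group can also be overridden at launch time using hydra's standard command-line syntax, e.g. `accelerator=gpu model=gin`.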

[In this document](https://github.com/datamol-io/graphium/tree/main/expts/hydra-configs#readme), we describe in detail how each of these possibilities is achieved and how users can benefit from this design to get the most out of Graphium with as little configuration as possible.