Skip to content

Commit d520c65

Browse files
committed
init
1 parent bf21957 commit d520c65

9 files changed

+1194
-7
lines changed

README.md

+32-7
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@ This repository contains the PyTorch code for our paper "Spatial Transcriptomics
44

55
> [paper]() | [arxiv](https://arxiv.org/pdf/2401.14772)
66
7-
**The code will come soon!**
8-
9-
107
## Introduction
118
Spatial transcriptomics (ST) captures gene expression in fine-grained, distinct regions (i.e., windows) of a tissue slide. Traditional supervised learning frameworks applied to model ST are constrained to predicting expression of gene types seen during training from slide image windows, failing to generalize to unseen gene types. To overcome this limitation, we propose a semantic guided network, a pioneering zero-shot gene expression prediction framework. Considering that a gene type can be described by its functionality and phenotype, we dynamically embed a gene type into a vector according to its functionality and phenotype, and employ this vector to project slide image windows to gene expression in feature space, unleashing zero-shot expression prediction for unseen gene types. The gene type functionality and phenotype are queried with a carefully designed prompt from a pre-trained large language model. On standard benchmark datasets, we demonstrate competitive zero-shot performance compared to past state-of-the-art supervised learning approaches.
129

@@ -20,16 +17,44 @@ Spatial transcriptomics (ST) captures gene expression fine-grained distinct regi
2017
<img src="asset/model.png" width="500"/>
2118
</div>
2219

23-
## Requirements
20+
## Dependencies
21+
* python 3.10.13
22+
* pytorch_lightning 1.6.4
23+
* tifffile 2024.2.12
24+
* Pillow 10.2.0
25+
* scanpy 1.10.2
26+
* torch 2.2.1+cu118
2427

25-
Please refer to [requirements.txt](./requirements.txt).
28+
## Dataset
29+
* Obtain [10xgenomics dataset](https://www.10xgenomics.com/resources/datasets?query=&page=1&configure%5Bfacets%5D%5B0%5D=chemistryVersionAndThroughput&configure%5Bfacets%5D%5B1%5D=pipeline.version&configure%5BhitsPerPage%5D=500&configure%5BmaxValuesPerFacet%5D=1000).
2630

27-
## How to run
31+
## Train SGN
32+
* Change system directory
33+
```bash
34+
cd v1
35+
```
2836

37+
* Extract features and build graph
2938
```bash
30-
python main.py
39+
40+
python3 extract_feature.py --file_path <path/to/data>  # Replace <path/to/data> with the directory where the downloaded data is stored. Remember to unzip the spatial.zip file first.
41+
python3 generate_graph.py --file_path <path/to/data>  # Replace <path/to/data> with the directory where the downloaded data is stored. Remember to unzip the spatial.zip file first.
42+
python3 name_to_feature.py
43+
3144
```
3245

46+
* Gene expression prediction
47+
```bash
48+
cd ../
49+
python3 main.py # Feel free to adjust the arguments as necessary.
50+
```
51+
52+
## Contact
53+
If you have any questions, please drop [me](mailto:[email protected]?subject=[GitHub]SGN) an email.
54+
55+
56+
## Acknowledgement
57+
EVA-CLIP is built using the awesome [timm](https://github.com/huggingface/pytorch-image-models).
3358

3459
## Citation
3560

main.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import argparse
2+
import os
3+
from model import GNN
4+
import pytorch_lightning as pl
5+
from functools import partial
6+
import torch
7+
import collections
8+
from train import TrainerModel
9+
from sklearn.model_selection import KFold
10+
import glob
11+
import torch_geometric
12+
13+
14+
def load_dataset(pts,file_path):
    """Load the saved graphs whose file name ends with one of the given ids.

    Args:
        pts: Iterable of ids (for example the integer indices returned by
            ``KFold.split``). Each id is converted with ``str`` and compared
            against the end of every file's stem.
        file_path: Directory searched (non-recursively) for ``*.pt`` files.

    Returns:
        A list holding the ``torch.load`` result of every matching file, in
        sorted file-name order. Each file is loaded at most once.
    """
    all_files = sorted(glob.glob(f"{file_path}/*.pt"))
    print(all_files)

    wanted = [str(p) for p in pts]

    def _matches(stem, token):
        # A plain suffix test would let id "1" also select "11.pt" and
        # "21.pt", leaking graphs between the train and test splits. Reject a
        # match when the id is only the tail of a longer number.
        if not stem.endswith(token):
            return False
        prefix = stem[:len(stem) - len(token)]
        return not (prefix[-1:].isdigit() and token[:1].isdigit())

    selected_files = []
    for path in all_files:
        stem = os.path.splitext(os.path.basename(path))[0]
        if any(_matches(stem, token) for token in wanted):
            # NOTE(review): torch.load unpickles the file; only use it on
            # graph files produced by a trusted preprocessing step.
            graph = torch.load(path)
            print(graph)
            selected_files.append(graph)
    return selected_files
25+
26+
def main(args,idx):
    """Train and validate the GNN on one cross-validation fold.

    The ``*.pt`` files under ``args.file_path`` are split into three shuffled
    folds (fixed seed 12345); the fold selected by ``args.fold`` supplies the
    train/validation indices, which are handed to ``load_dataset``.

    Args:
        args: Parsed command-line namespace. ``args.folder_name`` is
            overwritten with ``"log/<idx>"`` by this function.
        idx: Index of the current run; used to name the log directory.
    """
    fold_files = glob.glob(f"{args.file_path}/*.pt")
    splitter = KFold(n_splits=3, shuffle=True, random_state=12345)
    # Each entry is a (train_indices, test_indices) pair.
    folds = list(splitter.split(fold_files))

    cwd = os.getcwd()

    def write(director,name,*string):
        # Append the space-joined arguments as one line to <director>/<name>.
        line = " ".join(str(item) for item in string)
        with open(os.path.join(director,name),"a") as f:
            f.write(line + "\n")

    # NOTE(review): this replaces any user-supplied --folder_name with
    # "log/<idx>" — confirm whether the command-line value should be honoured.
    args.folder_name = "log" + "/" + str(idx)
    store_dir = args.folder_name + "/" + "checkpoints_" + str(args.fold) + "/"
    # File-backed logger; named `log` so it does not shadow the builtin print.
    log = partial(write, cwd, args.folder_name + "/" + "log_f" + str(args.fold))

    # Creating the checkpoint directory also creates the log directory above it.
    os.makedirs(store_dir, exist_ok=True)

    log(args)

    train_patient, test_patient = folds[args.fold]

    train_dataset = load_dataset(train_patient,args.file_path)
    test_dataset = load_dataset(test_patient,args.file_path)

    train_loader = torch_geometric.loader.DataLoader(
        train_dataset,
        batch_size=1,
    )

    test_loader = torch_geometric.loader.DataLoader(
        test_dataset,
        batch_size=1,
    )

    log(len(train_loader), len(test_loader))

    model = GNN(args.hidden_channels, args.embed_dim, args.out_channels, args.gnn_layer,args.feature_dim,args.name_dim)
    CONFIG = collections.namedtuple('CONFIG', ['lr', 'logfun', 'verbose_step', 'weight_decay', 'store_dir'])
    config = CONFIG(args.lr, log, args.verbose_step, args.weight_decay,store_dir)

    # Optionally warm-start from a saved state dict, mapped onto the CPU.
    if args.checkpoints is not None:
        model.load_state_dict(torch.load(args.checkpoints,map_location = torch.device("cpu")))

    model = TrainerModel(config, model,args.meta, args.name_feature)

    # Lightning's own logger and checkpoint callback are switched off here.
    plt = pl.Trainer(max_epochs = args.epoch,num_nodes=1, gpus=args.gpus, val_check_interval = args.val_interval,checkpoint_callback = False, logger = False)
    plt.fit(model,train_dataloaders=train_loader,val_dataloaders=test_loader)
81+
82+
if __name__ == "__main__":
    # Command-line options as (flag, default, type) rows; they are registered
    # in this order, which is also the order shown by --help.
    option_table = (
        ("--epoch", 300, int),
        ("--fold", 0, int),
        ("--gpus", 1, int),
        ("--acce", "ddp", str),
        ("--val_interval", 0.8, float),
        ("--lr", 1e-4*5, float),
        ("--verbose_step", 10, int),
        ("--weight_decay", 1e-4, float),
        ("--checkpoints", None, str),
        ("--output", None, str),
        ("--folder_name", "log", str),
        ("--runs", 1, int),
        ("--file_path", "extracted_feature/resnet18/graph", str),
        ("--name_feature", "name_feature/Intel/neural-chat-7b-v3-1", str),
        ("--meta", "preprocess/", str),
        ("--feature_dim", 512, int),
        ("--name_dim", 4096, int),
        ("--hidden_channels", 512, int),
        ("--embed_dim", 256, int),
        ("--out_channels", 256, int),
        ("--gnn_layer", 4, int),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, value_type in option_table:
        parser.add_argument(flag, default=default_value, type=value_type)

    args = parser.parse_args()

    # Every run trains all three folds in turn; the --fold value given on the
    # command line is overwritten before each call to main.
    for idx in range(args.runs):
        for fold in range(3):
            args.fold = fold
            main(args,idx)
113+

0 commit comments

Comments
 (0)