3D reconstruction

  • 3D reconstruction is an important application of spatial alignment.

  • In this example, we use 3d-OT to reconstruct 3D structure from multiple Stereo-seq slices of mouse embryos from E11.5 to E16.5.

Loading packages

[1]:
from lib_3d_OT.utils import *
import scanpy as sc
import numpy as np
import pandas as pd
import torch
from lib_3d_OT.threeDrecon import *
from lib_3d_OT.plot import *
import torch.optim as optim
import warnings
warnings.filterwarnings("ignore")
R[write to console]: mclust version 6.1.1
Type 'citation("mclust")' for citing this R package in publications.

[ ]:
# The original run used the second GPU ("cuda:1"); on a single-GPU machine use "cuda:0".
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
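The cell above requests `"cuda:1"` whenever CUDA is available, which fails on a machine with only one GPU. A small fallback keeps the original choice when a second GPU exists and otherwise drops to GPU 0 or the CPU. `pick_device_name` is a hypothetical helper, not part of 3d-OT.

```python
def pick_device_name(preferred_index: int, n_gpus: int) -> str:
    """Return 'cuda:<i>' for the preferred GPU if it exists, else 'cuda:0', else 'cpu'."""
    if n_gpus <= 0:
        return "cpu"
    if 0 <= preferred_index < n_gpus:
        return f"cuda:{preferred_index}"
    return "cuda:0"


# With PyTorch installed:
# device = torch.device(pick_device_name(1, torch.cuda.device_count()))
```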
[3]:
set_seed(7)
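`set_seed` comes from `lib_3d_OT.utils`, and its body is not shown here. Helpers of this kind usually seed every random number generator the pipeline touches. The sketch below (`set_seed_sketch` is a hypothetical name) seeds Python's `random`, plus NumPy and PyTorch when they are installed.

```python
import random

try:
    import numpy as np
except ImportError:  # NumPy is optional for this sketch
    np = None
try:
    import torch
except ImportError:  # PyTorch is optional for this sketch
    torch = None


def set_seed_sketch(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (all devices) with the same value."""
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)           # seeds the RNG on all devices
        torch.cuda.manual_seed_all(seed)  # explicit; silently ignored without CUDA


set_seed_sketch(7)
first = random.random()
set_seed_sketch(7)
second = random.random()
```

Seeding alone does not make GPU training bit-for-bit reproducible, since some cuDNN kernels are nondeterministic.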

Loading data

[4]:
# Paths point to the authors' machine; change data_dir to your own copy of the data.
data_dir = '/home/dbj/mouse/oT/different-time'
stages = ['E11.5', 'E12.5', 'E13.5', 'E14.5', 'E15.5', 'E16.5']

adatas = []
for stage in stages:
    adata = sc.read_h5ad(f'{data_dir}/{stage}.h5ad')
    # Keep a copy of the annotation as the ground-truth label
    adata.obs['truth'] = adata.obs['annotation']
    adatas.append(adata)

adata1, adata2, adata3, adata4, adata5, adata6 = adatas

Before alignment, we visualize each slice in 2D, colored by its annotation.

[5]:
sc.pl.spatial(adata1, spot_size=3, color='annotation')
sc.pl.spatial(adata2, spot_size=3, color='annotation')
sc.pl.spatial(adata3, spot_size=3, color='annotation')
sc.pl.spatial(adata4, spot_size=3, color='annotation')
sc.pl.spatial(adata5, spot_size=3, color='annotation')
sc.pl.spatial(adata6, spot_size=3, color='annotation')
[Figure: spatial plots of the six slices (E11.5 to E16.5) before alignment, colored by annotation]

Building the spatiotemporal developmental trajectory with 3d-OT

  • Train an encoder for each slice in every pair of adjacent slices
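The training log reports one model per slice for each consecutive pair ("Processing pair: 0 -> 1", then adata0 and adata1). We have not inspected `pairwise_dpca_and_train`; assuming `join='inner'` means, as in `anndata.concat`, that only genes shared by both slices of a pair are kept, the per-pair setup looks roughly like this. `adjacent_pairs`, `shared_genes` and the toy gene lists are hypothetical.

```python
def adjacent_pairs(n_slices):
    """Consecutive slice pairs (0, 1), (1, 2), ... matching 'Processing pair: i -> i+1'."""
    return [(i, i + 1) for i in range(n_slices - 1)]


def shared_genes(genes_a, genes_b):
    """Genes present in both slices, in slice A's order (an 'inner' join)."""
    in_b = set(genes_b)
    return [g for g in genes_a if g in in_b]


# Toy gene lists; with real data use list(adata.var_names) for each slice.
genes = [["Actb", "Myh6", "Alb"], ["Actb", "Alb", "Sox2"], ["Sox2", "Actb"]]
pairs = adjacent_pairs(len(genes))
per_pair = {p: shared_genes(genes[p[0]], genes[p[1]]) for p in pairs}
```

With six slices this yields five pairs, consistent with the five "Processing pair" blocks in the log.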

[6]:
adatalist = [adata1, adata2, adata3, adata4, adata5, adata6]
corrected_adatas, graphs, best_models = pairwise_dpca_and_train(
    adatalist, join='inner', n_comps=50, neighbors=6, epochs=800, device=device
)
pclouds_list = graphs
Processing pair: 0 -> 1
Epoch 800/800, Loss: 1.360324, Min Loss: 1.350767
Training time for adata0 in pair 0->1: 33.29 seconds
Epoch 800/800, Loss: 1.299219, Min Loss: 1.265533
Training time for adata1 in pair 0->1: 29.00 seconds
Processing pair: 1 -> 2
Epoch 800/800, Loss: 1.195852, Min Loss: 1.195245
Training time for adata1 in pair 1->2: 30.43 seconds
Epoch 800/800, Loss: 1.591694, Min Loss: 1.591643
Training time for adata2 in pair 1->2: 30.39 seconds
Processing pair: 2 -> 3
Epoch 800/800, Loss: 1.538934, Min Loss: 1.537077
Training time for adata2 in pair 2->3: 29.82 seconds
Epoch 800/800, Loss: 1.274691, Min Loss: 1.246992
Training time for adata3 in pair 2->3: 31.71 seconds
Processing pair: 3 -> 4
Epoch 800/800, Loss: 1.234406, Min Loss: 1.220684
Training time for adata3 in pair 3->4: 30.51 seconds
Epoch 800/800, Loss: 1.518229, Min Loss: 1.501579
Training time for adata4 in pair 3->4: 33.14 seconds
Processing pair: 4 -> 5
Epoch 800/800, Loss: 1.594596, Min Loss: 1.571328
Training time for adata4 in pair 4->5: 34.24 seconds
Epoch 800/800, Loss: 1.244391, Min Loss: 1.240835
Training time for adata5 in pair 4->5: 36.87 seconds

  • Align each pair of adjacent slices
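In the alignment log, pair k aligns graph 2k+1 onto graph 2k (graph1 -> graph0 for Pair 0, up to graph9 -> graph8 for Pair 4). This suggests that `graphs` holds two entries per adjacent pair, the earlier slice at index 2k and the later slice at index 2k+1, and that "reverse" means the later slice is mapped back onto the earlier one. This reading is inferred from the log only; `pair_graph_indices` is a hypothetical helper.

```python
def pair_graph_indices(n_slices):
    """For each adjacent pair k, the graph indices aligned as 'graph(2k+1) -> graph(2k)'."""
    out = []
    for k in range(n_slices - 1):
        out.append({
            "pair": k,
            "source_graph": 2 * k + 1,  # later slice (k + 1)
            "target_graph": 2 * k,      # earlier slice (k)
            "slices": (k, k + 1),
        })
    return out


mapping = pair_graph_indices(6)
```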

[7]:
aligned_models = pairwise_align_reverse(graphs, best_models, device=device, nb_epochs=1, simk=2, otk=500)
Aligning pair: graph1 -> graph0 (Pair 0)
Time Pair 0,total_loss: 0.1150,smooth_flow_loss: 0.0712 Target Recon Loss: 0.00010835,Div Flow Loss: 0.0437
Aligning pair: graph3 -> graph2 (Pair 1)
Time Pair 0,total_loss: 0.1426,smooth_flow_loss: 0.0721 Target Recon Loss: 0.00006128,Div Flow Loss: 0.0704
Aligning pair: graph5 -> graph4 (Pair 2)
Time Pair 0,total_loss: 0.1394,smooth_flow_loss: 0.0722 Target Recon Loss: 0.00015596,Div Flow Loss: 0.0670
Aligning pair: graph7 -> graph6 (Pair 3)
Time Pair 0,total_loss: 0.1392,smooth_flow_loss: 0.0763 Target Recon Loss: 0.00005303,Div Flow Loss: 0.0628
Aligning pair: graph9 -> graph8 (Pair 4)
Time Pair 0,total_loss: 0.1233,smooth_flow_loss: 0.0780 Target Recon Loss: 0.00010411,Div Flow Loss: 0.0452
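Judging from the printed numbers alone, `total_loss` matches the unweighted sum of the three logged terms (smooth flow, target reconstruction and divergence) for every pair, so any loss weights in this run appear to be 1. The check below uses only the values copied from the log; it says nothing about how 3d-OT defines each term.

```python
# (total_loss, smooth_flow_loss, Target Recon Loss, Div Flow Loss) copied from the log
logged = [
    (0.1150, 0.0712, 0.00010835, 0.0437),
    (0.1426, 0.0721, 0.00006128, 0.0704),
    (0.1394, 0.0722, 0.00015596, 0.0670),
    (0.1392, 0.0763, 0.00005303, 0.0628),
    (0.1233, 0.0780, 0.00010411, 0.0452),
]

# Difference between the printed total and the sum of its parts, per pair
residuals = [total - (smooth + recon + div) for total, smooth, recon, div in logged]
```

Every residual is smaller than the 4-decimal rounding of the printed values.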

Visualize the developmental trajectory of Heart spots over time

[8]:
all_arrow_ends = plot_all_pairs_cell_type_flow(
    graphs=graphs,
    aligned_models=aligned_models,
    device=device,
    finaltruth='Heart',
    selected_cell_type="Heart",
    xlim=(-0.1, 1.1),
    ylim=(-0.1, 1.1),
    height_scale=1.0,
    # save_path="/home/dbj/mouse/flow_plots/all_pairs_flow.png"
)

pair: graph1 -> graph0
pair: graph3 -> graph2
pair: graph5 -> graph4
pair: graph7 -> graph6
pair: graph9 -> graph8
[Figure: flow of Heart spots across all adjacent slice pairs]
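`plot_all_pairs_cell_type_flow` returns `all_arrow_ends`, but how the arrows are computed is not shown here. One common way to turn a transport plan into per-spot arrows is the barycentric projection: each source spot's arrow ends at the plan-weighted mean of the target positions. The sketch applies it to the spots of one annotation; `plan`, `dst_xy` and `arrow_ends_for_type` are assumptions, not 3d-OT's API.

```python
import numpy as np


def arrow_ends_for_type(plan, dst_xy, labels, cell_type):
    """Barycentric arrow ends for source spots labelled `cell_type`.

    plan:   (n_src, n_dst) transport plan, one row per source spot
    dst_xy: (n_dst, 2) coordinates of the target slice
    labels: annotations of the n_src source spots
    """
    mask = np.asarray(labels) == cell_type
    P = plan[mask]
    mass = P.sum(axis=1, keepdims=True)  # mass sent by each selected spot (must be > 0)
    ends = (P @ dst_xy) / mass           # plan-weighted mean of matched target positions
    return mask, ends


plan = np.array([[0.2, 0.2, 0.0],
                 [0.0, 0.3, 0.0],
                 [0.0, 0.0, 0.3]])
dst_xy = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
labels = ["Heart", "Liver", "Heart"]
mask, ends = arrow_ends_for_type(plan, dst_xy, labels, "Heart")
```

The first Heart spot splits its mass between two targets, so its arrow ends halfway between them at (0.5, 0.0).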

Visualize the developmental trajectory of Liver spots over time

[9]:
all_arrow_ends = plot_all_pairs_cell_type_flow(
    graphs=graphs,
    aligned_models=aligned_models,
    device=device,
    finaltruth='Liver',
    selected_cell_type="Liver",
    xlim=(-0.1, 1.1),
    ylim=(-0.1, 1.1),
    height_scale=1.0,
    # save_path="/home/dbj/mouse/flow_plots/all_pairs_flow.png"
)
pair: graph1 -> graph0
pair: graph3 -> graph2
pair: graph5 -> graph4
pair: graph7 -> graph6
pair: graph9 -> graph8
[Figure: flow of Liver spots across all adjacent slice pairs]

Export the spatiotemporal developmental trajectory file and a Sankey plot

  • Set the output paths for the Sankey plot and the flow file

[10]:
# The authors' output locations; change them to a writable directory of your own.
save_path_sankey = "/home/dbj/mouse/multi_slice_sankey_plot.html"
csv_output_path = "/home/dbj/mouse/multi_slice_flow_file.csv"
[11]:
plot_multi_slice_sankey_with_same_label_priority(
    graphs=graphs,
    aligned_models=aligned_models,
    device=device,
    min_flow_threshold=80,
    save_path=save_path_sankey,
    csv_output_path=csv_output_path
)
 54939 pair flow
folw file save: /home/dbj/mouse/multi_slice_flow_file.csv
filter: 194  > 80)
sankeplot save: /home/dbj/mouse/multi_slice_sankey_plot.html
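The printed summary appears to report 54,939 spot-level flows, of which 194 label-to-label links pass the `min_flow_threshold=80` filter ("> 80"). Assuming that reading (we have not read the function's source), building the Sankey links amounts to counting flows per (Start_Label, End_Label) pair and keeping only links with more than the threshold. `sankey_links` and the toy flows are hypothetical.

```python
from collections import Counter


def sankey_links(flows, min_flow_threshold):
    """Count (start_label, end_label) flows and keep links strictly above the threshold."""
    counts = Counter(flows)
    return {link: n for link, n in counts.items() if n > min_flow_threshold}


flows = ([("slice0_Brain", "slice1_Brain")] * 5
         + [("slice0_Heart", "slice1_Heart")] * 2
         + [("slice0_Brain", "slice1_Heart")] * 1)
links = sankey_links(flows, min_flow_threshold=2)
```

Each surviving key becomes one Sankey band whose width is its count.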

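Inspecting the flow file. Each row pairs a start spot (Start_Label, Start_X/Y/Z) with an end spot in the next slice (End_Label, End_X/Y/Z); in the rows shown below, the Z coordinate equals the slice index.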

[15]:
flow_file = pd.read_csv(csv_output_path)
[16]:
flow_file.head(5)
[16]:
    Start_Label     End_Label   Start_X   Start_Y  Start_Z     End_X     End_Y  End_Z
0  slice0_Brain  slice1_Brain  0.527350  0.356488      0.0  0.452229  0.309496    1.0
1  slice0_Brain  slice1_Brain  0.300422  0.138665      0.0  0.268525  0.127301    1.0
2  slice0_Brain  slice1_Brain  0.633786  0.195020      0.0  0.693628  0.247989    1.0
3  slice0_Brain  slice1_Brain  0.268634  0.199662      0.0  0.198457  0.210574    1.0
4  slice0_Brain  slice1_Brain  0.271394  0.182702      0.0  0.188454  0.235523    1.0
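Because every row stores a start and an end position, the in-plane displacement of each spot between adjacent slices follows directly from the coordinate columns. The sketch summarizes step lengths per label transition using only the standard library; the five rows printed above are inlined so it runs standalone, and `steps`/`mean_step` are illustrative names.

```python
import csv
import io
import math
from collections import defaultdict

# The first rows of the flow file, as printed above
text = """Start_Label,End_Label,Start_X,Start_Y,Start_Z,End_X,End_Y,End_Z
slice0_Brain,slice1_Brain,0.527350,0.356488,0.0,0.452229,0.309496,1.0
slice0_Brain,slice1_Brain,0.300422,0.138665,0.0,0.268525,0.127301,1.0
slice0_Brain,slice1_Brain,0.633786,0.195020,0.0,0.693628,0.247989,1.0
slice0_Brain,slice1_Brain,0.268634,0.199662,0.0,0.198457,0.210574,1.0
slice0_Brain,slice1_Brain,0.271394,0.182702,0.0,0.188454,0.235523,1.0
"""

# In-plane step length of each flow, grouped by (start label, end label)
steps = defaultdict(list)
for row in csv.DictReader(io.StringIO(text)):
    dx = float(row["End_X"]) - float(row["Start_X"])
    dy = float(row["End_Y"]) - float(row["Start_Y"])
    steps[(row["Start_Label"], row["End_Label"])].append(math.hypot(dx, dy))

mean_step = {link: sum(v) / len(v) for link, v in steps.items()}
```

On the full file, the same grouping (or a pandas `groupby`) shows which label transitions move furthest between stages.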