Introduction¶
We saw in analysis-01-dynamics that we could effectively separate the corn plants from the NPC. Here we want to show something fairly obvious: that we can replicate that work using soft thresholding, as opposed to hard thresholding. This shouldn't be too surprising, as a sigmoid function approaches a step function as its slope increases.
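The claim about the sigmoid can be checked numerically. The following is a minimal sketch, not part of the original analysis: the sample points `xs` and the helper `max_gap` are illustrative assumptions. It measures how far `sigmoid(slope * x)` is from a step function at a few points away from 0.

```python
import math

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def step(x):
    """Hard threshold at 0."""
    return 1.0 if x > 0 else 0.0

# Illustrative sample points on both sides of the threshold
xs = [-1.0, -0.1, -0.01, 0.01, 0.1, 1.0]

def max_gap(slope):
    """Largest |sigmoid(slope * x) - step(x)| over the sample points."""
    return max(abs(sigmoid(slope * x) - step(x)) for x in xs)

gaps = {slope: max_gap(slope) for slope in (1.0, 10.0, 100.0, 1000.0)}
for slope, gap in gaps.items():
    print(f"slope={slope:7.1f}  max gap from step function={gap:.6f}")
```

The gap shrinks as the slope grows, which is why large scale parameters let a soft threshold stand in for a hard one.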
We break this up into two parts. First, in part "A", we replicate the work from analysis-01-dynamics using soft thresholding. There is no inference here, but we do make use of functions from the "rtphenos" package. Second, in part "B", we show that we can infer the parameters of interest for the same model. This can be computationally intensive, so we have split it off as a separate notebook that can be run in, e.g., Google Colab.
In inference-01-dynamics-01-univariate, we built some intuition about how to build a good, basic model, and we use some of that intuition here, so you might want to review that notebook.
Setup¶
Import the required packages and determine the device we will use for computations.
import os
import copy
import h5py
import pandas as pd
import numpy as np
import progressbar
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from matplotlib import pyplot as plt
import plotnine as gg
import importlib as il
il.import_module("rtphenos")
<module 'rtphenos' from '/home/jesse/rtphenos/src/rtphenos/__init__.py'>
from rtphenos.detects1 import transforms as transforms1
%matplotlib inline
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
device(type='cpu')
Data to use¶
Get RootTracker deployment data.
DATA_DIR = "/usr/local/share/rtphenos-workspace"
rt_deps = pd.read_csv(os.path.join(DATA_DIR, "deployments", "rt_acc1_deps.csv"))
rt_deps.head()
| | deployment_id | roottracker_serial | ram_serial | x | y | bench | crop | genotype | order | pot_on_bench | pot_on_row | rep | row | deployment_name | NPC |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 54 | EG00043 | D76080D25B7EA126 | 1.0 | 11.0 | 2 | Soy | PI416937 | 11 | 2 | 11 | 1 | 1 | ACC Test 1 | False |
| 1 | 54 | EG00044 | D76080D25B3BD828 | 1.0 | 3.0 | 1 | Wheat | Comet | 3 | 3 | 3 | 1 | 1 | ACC Test 1 | False |
| 2 | 54 | EG00045 | D76080D25B7F7526 | 4.0 | 7.0 | 1 | Cotton | PI513384 | 115 | 7 | 7 | 1 | 4 | ACC Test 1 | False |
| 3 | 54 | EG00046 | D76080D25B3C9628 | 2.0 | 11.0 | 2 | Soy | PI416937 | 47 | 2 | 11 | 1 | 2 | ACC Test 1 | False |
| 4 | 54 | EG00047 | D76080D25B3A6C28 | 3.0 | 8.0 | 1 | Cotton | PI513384 | 80 | 8 | 8 | 1 | 3 | ACC Test 1 | False |
We want to limit our attention to just corn and NPC devices, since we believe they will have the clearest separation. We also want to remove devices with odd data.
The first definition of bad_roottrackers corresponds to those with odd data coming in on paddles 2 and 3. The second (commented out) corresponds to those plus the top 8 devices in terms of integrated mean-velocity-below-a-threshold, which is balanced between corn and NPC. The filtering we did in our previous analysis is critical for eliminating these NPCs. Here we can do something similar using the RemoveSimultaneous module defined below.
bad_roottrackers = ['EG00237', 'EG00226']
# bad_roottrackers = ['EG00237', 'EG00226', 'EG00228', 'EG00062', 'EG00157', 'EG00096', 'EG00160', 'EG00103', 'EG00159', 'EG00140']
rt_deps_to_use = rt_deps.query(f"(crop == 'Corn' | crop == 'NPC') & ~(roottracker_serial in {bad_roottrackers})").reset_index(drop=True).reset_index(drop=False)
rt_deps_to_use['rt_serial'] = rt_deps_to_use['roottracker_serial']
rt_deps_to_use.shape
(46, 17)
Load data¶
We use the Dataset class from PyTorch to access our data.
class Dynamics(Dataset):
    def __init__(self, rt_deps, base_dir, start=2000, end=4000, slice=[3,5]):
        self.iter_idx = 0
        if isinstance(rt_deps, str):
            # rt_deps is a path to a csv file of deployments
            self.rt_deps = pd.read_csv(rt_deps)
        elif isinstance(rt_deps, pd.DataFrame):
            self.rt_deps = rt_deps.reset_index()
        else:
            raise ValueError("rt_deps must be a path to a csv file or a DataFrame")
        self.base_dir = base_dir
        self.key = 'dyn'
        # One cache slot per device, filled on first access
        self.cache = [None for i in range(len(self.rt_deps))]
        self.start = start
        self.end = end
        self.slice = slice

    def __len__(self):
        return self.rt_deps.shape[0]

    def get_path(self, idx):
        rt_serial = self.rt_deps['roottracker_serial'][idx]
        full_path = os.path.join(self.base_dir, self.rt_deps['deployment_name'][idx], f"{rt_serial}.h5")
        if not os.path.exists(full_path):
            raise Exception(f"Path {full_path} does not exist")
        return full_path

    def __getitem__(self, idx):
        if self.cache[idx] is None:
            # Load the requested time range, then keep only the selected channels
            with h5py.File(self.get_path(idx), 'r') as h5r:
                self.cache[idx] = torch.tensor(
                    np.array(
                        h5r[self.key][self.start:self.end,:,:,:]
                    )[:,:,:,self.slice]
                ).float()
        # Label is 1.0 for a plant and 0.0 for NPC
        return self.cache[idx], torch.tensor(~self.rt_deps['NPC'][idx], dtype=torch.float32)

    def lookup(self, rt_serial):
        out = self.rt_deps.query(f"roottracker_serial == '{rt_serial}'")
        nout = len(out)
        if nout == 1:
            return out.index[0]
        elif nout > 1:
            raise Exception("Found multiple")
        else:
            raise Exception("Found none")

    def __iter__(self):
        self.iter_idx = 0
        return self

    def __next__(self):
        if self.iter_idx < len(self):
            out = self[self.iter_idx]
            self.iter_idx += 1
            return out
        raise StopIteration
# Should be balanced
rt_deps_to_use['crop'].value_counts()
crop
Corn    23
NPC     23
Name: count, dtype: int64
Data Loaders and data¶
Index 3 is mean velocity, while index 5 is standard deviation of velocity, both for the second coordinate. We use both here.
BATCH_SIZE=rt_deps_to_use.shape[0]
BATCH_SIZE
46
# For our work below, it is critical to use all of the data
dyn_data = Dynamics(rt_deps_to_use, os.path.join(DATA_DIR, "dynamics"), start=0, end=8043, slice=[3,5])
# dyn_data[0][0][:,0,0]
dyn_loader = DataLoader(dyn_data, batch_size=BATCH_SIZE, shuffle=False)
y_all = torch.stack([d[1] for d in dyn_data])
X_all = torch.stack([d[0] for d in dyn_data])
rts = rt_deps_to_use['roottracker_serial']
# X_all[0,:,0,0,:]
Models¶
Here we build the pieces that make up the models.
Modules¶
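The soft threshold used here is not just a sigmoid function. Letting $\sigma$ be a sigmoid function and $r$ be the ReLU function, we use $$2 \cdot r(\sigma(x) - 0.5).$$
Since $\sigma(x) - 0.5 \le 0$ whenever $x \le 0$, the output is exactly 0 there, so we still have a "hard" threshold at 0. Beyond that the threshold is soft: the output rises smoothly from 0 toward 1 as $x$ increases.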
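As a quick numerical check of this behavior, here is a minimal scalar sketch in plain Python. The `shift` and `scale` arguments mirror the role of the ShiftThenScale module, but the function `soft_threshold` itself is an illustrative assumption, not code from the notebook or from rtphenos.

```python
import math

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def soft_threshold(x, shift=0.0, scale=1.0):
    """Scalar version of 2 * relu(sigmoid(scale * (x - shift)) - 0.5)."""
    z = scale * (x - shift)
    return 2.0 * max(sigmoid(z) - 0.5, 0.0)

for x in (-3.0, 0.0, 0.5, 1.0, 3.0):
    print(f"x={x:5.1f}  soft_threshold={soft_threshold(x):.4f}")

# A large scale makes the soft part nearly as sharp as a step
print(soft_threshold(0.1, scale=100.0))
```

The output is exactly 0 for non-positive inputs and increases toward 1 for positive ones. A negative `scale` flips the direction, so the function is active below `shift` rather than above it.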
class ShiftThenScale(nn.Module):
    def __init__(self, scale_mag=1.):
        super().__init__()
        self.shift = nn.Parameter(torch.randn(()))
        self.scale = nn.Parameter(torch.randn(()))
        self.scale_mag = scale_mag

    def forward(self, x):
        return self.scale_mag * self.scale * (x - self.shift)

    def __repr__(self):
        return f"ShiftThenScale({self.scale_mag})"
class SoftThreshold(nn.Module):
    def __init__(self):
        super().__init__()
        self.L1 = ShiftThenScale()
        self.L2 = ShiftThenScale()
        self.L3 = ShiftThenScale()
        self.S = nn.Sigmoid()
        self.R = nn.ReLU()

    def forward(self, x):
        last_dim = len(x.size()) - 1
        shp = list(x.shape)
        shp[last_dim] = 3
        # Allocate on the same device and with the same dtype as the input
        y = x.new_zeros(shp)
        # Channel 0: mean velocity
        y[:,:,:,:,0] = self.L1(x[:,:,:,:,0])
        # Channel 1: standard deviation of velocity, thresholded twice on the log scale
        x1log = torch.log(x[:,:,:,:,1] + 1e-10)
        y[:,:,:,:,1] = self.L2(x1log)
        y[:,:,:,:,2] = self.L3(x1log)
        y = self.S(y)
        y = 2.0 * (self.R(y - 0.5))
        # The product acts as a soft "and" of the three thresholds
        z = torch.prod(y, dim=last_dim, keepdim=True)
        z = z.squeeze(last_dim)
        return z
class Integrate(nn.Module):
    def __init__(self, sum_dim):
        super().__init__()
        self.sum_dim = sum_dim

    def forward(self, x):
        # Average over the requested dimensions
        x = torch.mean(x, dim=self.sum_dim, keepdim=True)
        return x

    def __repr__(self):
        return f"Integrate(sum_dim={self.sum_dim})"
class RemoveSimultaneous(nn.Module):
    def __init__(self):
        super().__init__()
        self.R = nn.ReLU()

    def forward(self, x):
        shp = x.size()
        npads = shp[2]
        y = torch.empty_like(x)
        for i in range(npads):
            others = [True] * npads
            others[i] = False
            val1, _ = torch.max(x[:,:,others,:], dim=2) # Max along other paddles
            val2, _ = torch.max(val1, dim=2, keepdim=True) # Max along electrodes
            # y[:,:,i,:] = x[:,:,i,:] - val1
            y[:,:,i,:] = x[:,:,i,:] - val2
        # Keep only the activation that exceeds everything seen on the other paddles
        y = self.R(y)
        return y
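To make the effect of RemoveSimultaneous concrete, here is a minimal pure-Python sketch of the same arithmetic for a single time step. The function name `remove_simultaneous` and the toy activation values are assumptions for illustration; the module above applies the same rule to every batch element and time step at once.

```python
def remove_simultaneous(frame):
    """frame[p][e] is the activation of electrode e on paddle p at one time step.

    Each paddle's activations are reduced by the largest activation found on
    any electrode of any other paddle, then clipped at zero (ReLU).
    Assumes at least two paddles.
    """
    out = []
    for i, row in enumerate(frame):
        others_max = max(
            v for j, other in enumerate(frame) if j != i for v in other
        )
        out.append([max(v - others_max, 0.0) for v in row])
    return out

# One paddle active on its own: the activation survives
lone = [[0.0, 0.9], [0.0, 0.0], [0.0, 0.0]]
print(remove_simultaneous(lone))

# Two paddles active at the same time: both are almost entirely suppressed
simultaneous = [[0.0, 0.9], [0.8, 0.0], [0.0, 0.0]]
print(remove_simultaneous(simultaneous))
```

An isolated detection passes through unchanged, while detections that occur on several paddles at once mostly cancel each other.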
Model definitions (wrapped as function calls)¶
def make_model_1():
    out = nn.Sequential(
        SoftThreshold(),
        RemoveSimultaneous(), # This makes a difference
        Integrate(sum_dim = (1, 2, 3)),
        ShiftThenScale(1.0),
        nn.Flatten(start_dim=0)
    )
    return out
Model 1¶
Here we define the parameters of the model using those from analysis-01-dynamics. From the way we constructed the model, using a hard threshold on the left, we should get nearly identical results.
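To see how the parameter choices in this section map onto hard cut-offs, the sketch below evaluates the three soft thresholds for a single measurement in plain Python. It uses the shift and scale values from the model_1_defaults dictionary in this section; the function names and test inputs are illustrative assumptions. With these signs, the product is close to 1 only when the mean velocity is below -0.25 and the log of the standard deviation lies between -2 and -1.

```python
import math

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def soft_threshold(x, shift, scale):
    """Scalar version of 2 * relu(sigmoid(scale * (x - shift)) - 0.5)."""
    return 2.0 * max(sigmoid(scale * (x - shift)) - 0.5, 0.0)

def detect(mean_vel, std_vel, scale=100.0):
    """Product of the three soft thresholds for one measurement."""
    log_std = math.log(std_vel + 1e-10)
    a = soft_threshold(mean_vel, shift=-0.25, scale=-scale)  # active when mean_vel < -0.25
    b = soft_threshold(log_std, shift=-1.0, scale=-scale)    # active when log(std) < -1
    c = soft_threshold(log_std, shift=-2.0, scale=scale)     # active when log(std) > -2
    return a * b * c

print(detect(-0.5, math.exp(-1.5)))  # inside all three cut-offs
print(detect(0.0, math.exp(-1.5)))   # mean velocity too high
print(detect(-0.5, math.exp(-0.5)))  # standard deviation too large
print(detect(-0.5, math.exp(-3.0)))  # standard deviation too small
```

Because any factor that is exactly 0 zeroes the product, the three thresholds behave like a logical "and" of the hard cut-offs when the scale is large.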
model_1 = make_model_1()
# Large negative momentum_2 indicative of corn
scale1 = 100.
scale2 = 200.
model_1_defaults = {
# Momentum less than -0.25
'0.L1.scale': torch.tensor(-scale1),
'0.L1.shift': torch.tensor(-0.25),
# '0.L1.shift': torch.tensor(0.07), # This results in a LOT of detections and not very good detections
'0.L2.scale': torch.tensor(-scale1),
'0.L2.shift': torch.tensor(-1.0),
'0.L3.scale': torch.tensor(scale1),
'0.L3.shift': torch.tensor(-2.0),
'3.shift': torch.tensor(0.0),
'3.scale': torch.tensor(scale2)
}
model_1.load_state_dict(model_1_defaults)
<All keys matched successfully>
with torch.no_grad():
    yhat_init_1 = model_1(X_all)
    print(yhat_init_1.size())
    print(yhat_init_1)
torch.Size([46]) tensor([0.0737, 0.0110, 0.0921, 0.0988, 0.0923, 0.0682, 0.0085, 0.1100, 0.0774, 0.0082, 0.1069, 0.0649, 0.1049, 0.0260, 0.1069, 0.0184, 0.0040, 0.0196, 0.0066, 0.0882, 0.1411, 0.1127, 0.0125, 0.0037, 0.0025, 0.0687, 0.1418, 0.0438, 0.1007, 0.1254, 0.0867, 0.0112, 0.0117, 0.0606, 0.1672, 0.0356, 0.0133, 0.1161, 0.0171, 0.0091, 0.0203, 0.0831, 0.1057, 0.0135, 0.1393, 0.0792])
If you do not include RemoveSimultaneous in the model, 4 NPC devices also show high "activations", just as they did before filtering in analysis-01-dynamics.
with torch.no_grad():
    plt.scatter(yhat_init_1.numpy(), y_all.numpy())
    plt.title("Ground truth vs. fraction of anomalous time (M1)")
    plt.xlabel("fraction of anomalous time")
    plt.ylabel("Plant / No plant (1 or 0)")
model_1_trunc = nn.Sequential(
SoftThreshold(),
RemoveSimultaneous()
)
model_1.state_dict()
OrderedDict([('0.L1.shift', tensor(-0.2500)), ('0.L1.scale', tensor(-100.)), ('0.L2.shift', tensor(-1.)), ('0.L2.scale', tensor(-100.)), ('0.L3.shift', tensor(-2.)), ('0.L3.scale', tensor(100.)), ('3.shift', tensor(0.)), ('3.scale', tensor(200.))])
for idx in ['0', '1']:
    model_1_trunc._modules[idx].load_state_dict(model_1._modules[idx].state_dict())
    # print(model_1_trunc._modules[idx].state_dict())
Z = model_1_trunc(X_all)
Filter¶
And now we do the thresholding as in analysis-01-dynamics.
temp_df_list = list()
for i, z in enumerate(Z):
    temp_df = transforms1.detection_trajectories(z > 0)
    temp_df['weight'] = 1.0
    temp_df['rt_serial'] = rt_deps_to_use['rt_serial'][i]
    temp_df['crop'] = rt_deps_to_use['crop'][i]
    temp_df_list.append(temp_df)
temp_df.head()
| | start | end | paddle | electrode | weight | rt_serial | crop |
|---|---|---|---|---|---|---|---|
| 0 | 3905 | 3906 | 0 | 8 | 1.0 | EG00240 | Corn |
| 1 | 3907 | 3913 | 0 | 8 | 1.0 | EG00240 | Corn |
| 2 | 5127 | 5128 | 0 | 8 | 1.0 | EG00240 | Corn |
| 3 | 3896 | 3902 | 0 | 10 | 1.0 | EG00240 | Corn |
| 4 | 4054 | 4057 | 0 | 10 | 1.0 | EG00240 | Corn |
traj_df_1 = pd.concat(temp_df_list)
traj_df_1.shape
(12849, 7)
traj_df_2 = transforms1.consolidate_trajectories(traj_df_1.query("electrode > 0"), grouping=['crop', 'rt_serial', 'paddle', 'electrode'], cush=6)
traj_df_2.shape
(7576, 8)
traj_df_3a = transforms1.flag_overlapping_trajectories(traj_df_2.sort_values(['crop', 'rt_serial', 'start', 'end']),
grouping=['crop', 'rt_serial', 'paddle'], cush=6)
# traj_df_3a.head(n=10)
traj_df_3 = traj_df_3a.groupby(['crop', 'rt_serial', 'paddle', 'group_num']).agg({
'start': 'min',
'end': 'max',
'electrode': 'min'
}).reset_index(drop=False)
traj_df_3.shape
(4820, 7)
# traj_df_3.head()
# Remove trajectories that overlap in time and on adjacent paddles (we do not use modular arithmetic here, so this is only approximate).
traj_df_4a = transforms1.flag_overlapping_trajectories(traj_df_3.sort_values(['crop', 'rt_serial', 'start', 'end']), grouping=['crop', 'rt_serial'], cush=6)
traj_df_4a['num_in_group'] = traj_df_4a.groupby(['group_num'])['group_num'].transform(lambda x: len(x))
traj_df_4a['length'] = traj_df_4a['end'] - traj_df_4a['start']
traj_df_4a.shape
(4820, 10)
traj_df_4a.head()
| | crop | rt_serial | paddle | group_num | start | end | electrode | overlap_prev | num_in_group | length |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Corn | EG00054 | 10 | 1 | 36 | 37 | 20 | False | 1 | 1 |
| 1 | Corn | EG00054 | 10 | 2 | 346 | 351 | 20 | False | 1 | 5 |
| 2 | Corn | EG00054 | 10 | 3 | 378 | 396 | 20 | False | 1 | 18 |
| 3 | Corn | EG00054 | 2 | 4 | 532 | 533 | 12 | False | 6 | 1 |
| 4 | Corn | EG00054 | 10 | 4 | 534 | 550 | 8 | True | 6 | 16 |
There are a lot of trajectories that occur on just 1 or 2 paddles. If you keep only the trajectories that have a single active paddle (num_in_group equal to 1), then you filter out roughly 2/3 of your detections! Whether we keep group sizes of 1 only or of 1 and 2, we still get reasonable separation of NPC and Corn below. We could also do this after the fact by modeling.
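The grouping idea can be illustrated with a small pure-Python sketch. This is not the rtphenos implementation of flag_overlapping_trajectories: the function `group_overlapping` and its merge rule (a trajectory joins the current group if it starts within `cush` steps of the latest end seen so far) are assumptions chosen to be consistent with the first rows of traj_df_4a shown above, whose start, end, and paddle values are reused here.

```python
from collections import Counter

def group_overlapping(trajs, cush=6):
    """Assign a group number to each (start, end, paddle) tuple.

    Hypothetical rule: after sorting by start, a trajectory joins the current
    group when it starts no later than `cush` steps after the latest end in
    that group; otherwise it opens a new group.
    """
    grouped = []
    group_num = 0
    latest_end = None
    for start, end, paddle in sorted(trajs):
        if latest_end is None or start > latest_end + cush:
            group_num += 1
            latest_end = end
        else:
            latest_end = max(latest_end, end)
        grouped.append((group_num, start, end, paddle))
    return grouped

trajs = [(36, 37, 10), (346, 351, 10), (378, 396, 10), (532, 533, 2), (534, 550, 10)]
grouped = group_overlapping(trajs)

# Count the members of each group and keep only the isolated trajectories
sizes = Counter(g for g, *_ in grouped)
isolated = [t for t in grouped if sizes[t[0]] == 1]
print(grouped)
print(isolated)
```

Under this rule the last two trajectories fall into the same group and are dropped, while the three isolated ones are kept.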
traj_df_4a['num_in_group'].value_counts()
num_in_group
1     1837
2      976
3      621
4      444
5      300
6      276
7      168
8       88
9       45
11      33
10      20
12      12
Name: count, dtype: int64
pd.crosstab(traj_df_4a['num_in_group'], traj_df_4a['crop'])
| num_in_group | Corn | NPC |
|---|---|---|
| 1 | 1462 | 375 |
| 2 | 764 | 212 |
| 3 | 468 | 153 |
| 4 | 312 | 132 |
| 5 | 230 | 70 |
| 6 | 192 | 84 |
| 7 | 112 | 56 |
| 8 | 64 | 24 |
| 9 | 27 | 18 |
| 10 | 20 | 0 |
| 11 | 11 | 22 |
| 12 | 12 | 0 |
# Lots of overlap in this NPC, but it doesn't all align
# traj_df_4a.query("rt_serial == 'EG00103' & num_in_group == 6")
# We remove roughly 2/3 of trajectories by removing things that occur at the same time!
traj_df_4 = traj_df_4a.query("num_in_group <= 1")
traj_df_4.shape
(1837, 10)
# traj_df_4.head()
traj_df_5 = traj_df_4.copy()
# traj_df_5 = traj_df_4.merge(crop_df, on='rt_serial')
traj_df_5.shape
(1837, 10)
# traj_df_5['length'] = traj_df_5['end'] - traj_df_5['start']
traj_df_5['day']= np.floor(traj_df_5['start'] / (12 * 24)).astype(int)
traj_df_5.head()
| | crop | rt_serial | paddle | group_num | start | end | electrode | overlap_prev | num_in_group | length | day |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Corn | EG00054 | 10 | 1 | 36 | 37 | 20 | False | 1 | 1 | 0 |
| 1 | Corn | EG00054 | 10 | 2 | 346 | 351 | 20 | False | 1 | 5 | 1 |
| 2 | Corn | EG00054 | 10 | 3 | 378 | 396 | 20 | False | 1 | 18 | 1 |
| 9 | Corn | EG00054 | 11 | 5 | 641 | 668 | 8 | False | 1 | 27 | 2 |
| 15 | Corn | EG00054 | 0 | 8 | 972 | 998 | 7 | False | 1 | 26 | 3 |
Plot¶
Let's look at the same plots we created previously. They all look very similar.
gg.ggplot(traj_df_5) + gg.geom_histogram(gg.aes(x='length', fill='crop'), bins=20)
<Figure Size: (640 x 480)>
# We integrate here, which can be critical to getting a difference if you do not filter out short-length anomalies.
traj_df_6 = traj_df_5.query("length > 6").groupby(['crop', 'rt_serial']).agg({'length': 'sum', 'start': 'size'}).reset_index(drop=False)
traj_df_6['count'] = traj_df_6.pop('start')
traj_df_6.shape
(46, 4)
# traj_df_6.head()
Looking at differences in cumulative time:
gg.ggplot(traj_df_6) + gg.geom_histogram(gg.aes(x='length', fill='crop'), bins=20) + \
gg.ggtitle("Distribution of cumulative anomalous trajectory time by crop")
<Figure Size: (640 x 480)>
Looking at differences in counts:
gg.ggplot(traj_df_6) + gg.geom_histogram(gg.aes(x='count', fill='crop'), bins=20) + \
gg.ggtitle("Distribution of number of anomalous trajectories by crop")
<Figure Size: (640 x 480)>
gg.ggplot(traj_df_5.query("length > 3")) + gg.geom_point(gg.aes('start', 'rt_serial', fill='crop')) + \
gg.ggtitle("Detection times by device")
<Figure Size: (640 x 480)>
traj_df_day = traj_df_5.groupby(['crop', 'rt_serial', 'day']).agg({'length': 'sum', 'start': 'size', 'electrode': 'mean'}).reset_index(drop=False)
traj_df_day['day'] = pd.Categorical(traj_df_day['day'], ordered=True)
# traj_df_day.head()
gg.ggplot(traj_df_day) + gg.geom_boxplot(gg.aes('day', 'start')) + gg.facet_wrap("~ crop") + \
gg.ggtitle("Distribution of number of detections by day")
<Figure Size: (640 x 480)>
gg.ggplot(traj_df_day) + gg.geom_boxplot(gg.aes('day', 'electrode')) + gg.facet_wrap("~ crop") + \
gg.ggtitle("Distribution of mean electrode by day")
<Figure Size: (640 x 480)>
Conclusion (Part A)¶
Above, we showed that the soft-thresholded version of our work from analysis-01-dynamics does indeed produce very similar results. In other words, we can construct a model (a neural network) that recapitulates the ad hoc approach that worked well previously. In part B, we show that we can learn the optimal parameters of this model from data. (See inference-01-dynamics-bivariate-B.)