Introduction¶

We saw in analysis-01-dynamics that we could effectively separate the corn plants from the NPC. Here we want to show something fairly obvious: that we can replicate that work using soft thresholding, as opposed to hard thresholding. This shouldn't be too surprising as a sigmoid function approaches a step function as its slope increases.

We break this up into two parts. First, in part "A", we replicate the work from analysis-01-dynamics using soft thresholding. There is no inference here, but we do make use of functions from the "rtphenos" package. Second, in part "B", we show that we can infer the parameters of interest for the same model. This can be computationally intensive, so we have split it off into a separate notebook that can be run in, e.g., Google Colab.

In inference-01-dynamics-01-univariate, we built some intuition about how to build a good, basic model, and we use some of that intuition here, so you might want to review that notebook.

Setup¶

Import packages and determine the device we will use for computations.

In [1]:
import os
import copy
import h5py
import pandas as pd
import numpy as np
import progressbar
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from matplotlib import pyplot as plt
import plotnine as gg
import importlib as il
In [2]:
il.import_module("rtphenos")
Out[2]:
<module 'rtphenos' from '/home/jesse/rtphenos/src/rtphenos/__init__.py'>
In [3]:
from rtphenos.detects1 import transforms as transforms1
In [4]:
%matplotlib inline
In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
Out[5]:
device(type='cpu')

Data to use¶

Get RootTracker deployment data.

In [6]:
DATA_DIR = "/usr/local/share/rtphenos-workspace"
In [7]:
rt_deps = pd.read_csv(os.path.join(DATA_DIR, "deployments", "rt_acc1_deps.csv"))
rt_deps.head()
Out[7]:
deployment_id roottracker_serial ram_serial x y bench crop genotype order pot_on_bench pot_on_row rep row deployment_name NPC
0 54 EG00043 D76080D25B7EA126 1.0 11.0 2 Soy PI416937 11 2 11 1 1 ACC Test 1 False
1 54 EG00044 D76080D25B3BD828 1.0 3.0 1 Wheat Comet 3 3 3 1 1 ACC Test 1 False
2 54 EG00045 D76080D25B7F7526 4.0 7.0 1 Cotton PI513384 115 7 7 1 4 ACC Test 1 False
3 54 EG00046 D76080D25B3C9628 2.0 11.0 2 Soy PI416937 47 2 11 1 2 ACC Test 1 False
4 54 EG00047 D76080D25B3A6C28 3.0 8.0 1 Cotton PI513384 80 8 8 1 3 ACC Test 1 False

We want to limit our attention to just corn and NPC devices, since we believe those will have the clearest separation. We also want to remove devices with odd data.

The first definition of bad_roottrackers corresponds to devices with odd data coming in on paddles 2 and 3. The second corresponds to those plus the top 8 devices in terms of integrated mean-velocity-below-a-threshold, a set that is balanced between corn and NPC. The filtering we did in our previous analysis was critical for eliminating these NPCs. Here we can do something similar using the RemoveSimultaneous module defined below.

In [8]:
bad_roottrackers = ['EG00237', 'EG00226']
# bad_roottrackers = ['EG00237', 'EG00226', 'EG00228', 'EG00062', 'EG00157', 'EG00096', 'EG00160', 'EG00103', 'EG00159', 'EG00140']
In [9]:
rt_deps_to_use = rt_deps.query(f"(crop == 'Corn' | crop == 'NPC') & ~(roottracker_serial in {bad_roottrackers})").reset_index(drop=True).reset_index(drop=False)
rt_deps_to_use['rt_serial'] = rt_deps_to_use['roottracker_serial']
rt_deps_to_use.shape
Out[9]:
(46, 17)

Load data¶

We use the Dataset class from PyTorch to access our data.

In [10]:
class Dynamics(Dataset):

    def __init__(self, rt_deps, base_dir, start=2000, end=4000, slice=[3,5]):
        self.iter_idx = 0
        if isinstance(rt_deps, str):
            self.rt_deps = pd.read_csv(rt_deps)
        elif isinstance(rt_deps, pd.DataFrame):
            self.rt_deps = rt_deps.reset_index()
        else:
            raise ValueError("rt_deps must be a path to a csv file or a DataFrame")
        self.base_dir = base_dir
        self.key = 'dyn'
        self.cache = [None for i in range(len(self.rt_deps))]
        self.start = start
        self.end = end
        self.slice = slice

    def __len__(self):
        return self.rt_deps.shape[0]

    def get_path(self, idx):
        rt_serial = self.rt_deps['roottracker_serial'][idx]
        full_path = os.path.join(self.base_dir, self.rt_deps['deployment_name'][idx], f"{rt_serial}.h5")
        if not os.path.exists(full_path):
            raise Exception(f"Path {full_path} does not exist")
        return full_path

    def __getitem__(self, idx):
        if self.cache[idx] is None:
            with h5py.File(self.get_path(idx), 'r') as h5r:
                self.cache[idx] = torch.tensor(
                    np.array(
                        h5r[self.key][self.start:self.end,:,:,:]
                    )[:,:,:,self.slice]
                ).float()
        return self.cache[idx], torch.tensor(~self.rt_deps['NPC'][idx], dtype=torch.float32)

    def lookup(self, rt_serial):
        out = self.rt_deps.query(f"roottracker_serial == '{rt_serial}'")
        nout = len(out)
        if nout == 1:
            return out.index[0]
        elif nout > 1:
            raise Exception("Found multiple")
        else:
            raise Exception("Found none")

    def __iter__(self):
        self.iter_idx = 0
        return self

    def __next__(self):
        if self.iter_idx < len(self):
            out = self[self.iter_idx]
            self.iter_idx += 1
            return out
        raise StopIteration
In [11]:
# Should be balanced
rt_deps_to_use['crop'].value_counts()
Out[11]:
crop
Corn    23
NPC     23
Name: count, dtype: int64

Data Loaders and data¶

Index 3 is mean velocity, while index 5 is the standard deviation of velocity, both for the second coordinate. We use both here.

In [12]:
BATCH_SIZE=rt_deps_to_use.shape[0]
BATCH_SIZE
Out[12]:
46
In [13]:
# For our work below, it is critical to use all of the data
dyn_data = Dynamics(rt_deps_to_use, os.path.join(DATA_DIR, "dynamics"), start=0, end=8043, slice=[3,5])
In [14]:
# dyn_data[0][0][:,0,0]
In [15]:
dyn_loader = DataLoader(dyn_data, batch_size=BATCH_SIZE, shuffle=False)
In [16]:
y_all = torch.stack([d[1] for d in dyn_data])
In [17]:
X_all = torch.stack([d[0] for d in dyn_data])
In [18]:
rts = rt_deps_to_use['roottracker_serial']
In [19]:
# X_all[0,:,0,0,:]

Models¶

Here we build the pieces that make up the models.

Modules¶

The soft threshold used here is not just a sigmoid function. Letting $\sigma$ be the sigmoid function and $r$ be the ReLU function, we use $$2\,r(\sigma(x) - 0.5).$$

This way we still have a "hard" threshold at 0 (the output is exactly zero whenever $\sigma(x) \le 0.5$, i.e. for $x \le 0$), but the transition is soft beyond that.
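
As a quick sanity check, the soft threshold can be evaluated directly with plain torch calls. This is a minimal sketch, separate from the modules defined below, just to show the shape of the function.

In [ ]:
# Sketch: the soft threshold 2 * relu(sigmoid(x) - 0.5) evaluated directly.
# It is exactly 0 for x <= 0 (the "hard" part) and rises smoothly toward 1.
def soft_threshold(x):
    return 2.0 * torch.relu(torch.sigmoid(x) - 0.5)

soft_threshold(torch.tensor([-2.0, 0.0, 0.5, 2.0, 10.0]))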

In [20]:
class ShiftThenScale(nn.Module):

    def __init__(self, scale_mag=1.):
        super().__init__()
        self.shift = nn.Parameter(torch.randn(()))
        self.scale = nn.Parameter(torch.randn(()))
        self.scale_mag = scale_mag

    def forward(self, x):
        return self.scale_mag * self.scale * (x - self.shift)

    def __repr__(self):
        return f"ShiftThenScale({self.scale_mag})"
In [21]:
class SoftThreshold(nn.Module):

    def __init__(self):
        super().__init__()
        self.L1 = ShiftThenScale()
        self.L2 = ShiftThenScale()
        self.L3 = ShiftThenScale()
        self.S = nn.Sigmoid()
        self.R = nn.ReLU()

    def forward(self, x):
        last_dim = len(x.size()) - 1
        shp = list(x.shape)
        shp[last_dim] = 3
        y = torch.zeros(shp)
        # Channel 0: mean velocity, shifted and scaled.
        y[:,:,:,:,0] = self.L1(x[:,:,:,:,0])
        # Channels 1 and 2: two shifted/scaled copies of log(std of velocity);
        # with opposite-sign scales (as in the defaults below) they select a band.
        x1log = torch.log(x[:,:,:,:,1] + 1e-10)
        y[:,:,:,:,1] = self.L2(x1log)
        y[:,:,:,:,2] = self.L3(x1log)
        # Soft threshold each channel, then multiply the three indicators
        # so that all conditions must hold simultaneously.
        y = self.S(y)
        y = 2.0 * (self.R(y - 0.5))
        z = torch.prod(y, dim=last_dim, keepdim=True)
        z = z.squeeze(last_dim)
        return z
In [22]:
class Integrate(nn.Module):
    def __init__(self, sum_dim):
        super().__init__()
        self.sum_dim = sum_dim

    def forward(self, x):
        # Note: this "integration" is an average over the given dims.
        return torch.mean(x, dim=self.sum_dim, keepdim=True)

    def __repr__(self):
        return f"Integrate(sum_dim={self.sum_dim})"
In [23]:
class RemoveSimultaneous(nn.Module):
    def __init__(self):
        super().__init__()
        self.R = nn.ReLU()

    def forward(self, x):
        shp = x.size()
        npads = shp[2]
        y = x.new(shp)
        for i in range(npads):
            others = [True] * npads
            others[i] = False
            val1, _ = torch.max(x[:,:,others,:], dim=2) # Max along other paddles
            val2, _ = torch.max(val1, dim=2, keepdim=True) # Max along electrodes
            # y[:,:,i,:] = x[:,:,i,:] - val1
            y[:,:,i,:] = x[:,:,i,:] - val2
        y = self.R(y)
        return y
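
To make the behavior of RemoveSimultaneous concrete, here is a minimal sketch on a tiny synthetic tensor (shapes chosen only for illustration): an activation survives only to the extent that it exceeds the largest activation on any other paddle at the same time step.

In [ ]:
# Sketch: RemoveSimultaneous on a toy tensor of shape
# (batch=1, time=2, paddles=3, electrodes=2).
toy = torch.zeros(1, 2, 3, 2)
toy[0, 0, 0, 0] = 1.0   # isolated activation on one paddle -> kept
toy[0, 1, :, 0] = 0.8   # simultaneous activation on all paddles -> removed
RemoveSimultaneous()(toy)[0]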

Model definitions (wrapped as function calls)¶

In [24]:
def make_model_1():
    out = nn.Sequential(
        SoftThreshold(),
        RemoveSimultaneous(), # This makes a difference
        Integrate(sum_dim = (1, 2, 3)),
        ShiftThenScale(1.0),
        nn.Flatten(start_dim=0)
    )
    return out

Model 1¶

Here we set the parameters of the model using those from analysis-01-dynamics. Because of the way we constructed the model, with a hard threshold on the left, we should get nearly identical results.

In [25]:
model_1 = make_model_1()
In [26]:
# Large negative momentum_2 indicative of corn
scale1 = 100.
scale2 = 200.
model_1_defaults = {
    # Momentum less than -0.25
    '0.L1.scale': torch.tensor(-scale1),
    '0.L1.shift': torch.tensor(-0.25),
    # '0.L1.shift': torch.tensor(0.07), # This results in a LOT of detections and not very good detections
    '0.L2.scale': torch.tensor(-scale1),
    '0.L2.shift': torch.tensor(-1.0),
    '0.L3.scale': torch.tensor(scale1),
    '0.L3.shift': torch.tensor(-2.0),
    '3.shift': torch.tensor(0.0),
    '3.scale': torch.tensor(scale2)
}
In [27]:
model_1.load_state_dict(model_1_defaults)
Out[27]:
<All keys matched successfully>
In [28]:
with torch.no_grad():
    yhat_init_1 = model_1(X_all)
    print(yhat_init_1.size())
    print(yhat_init_1)
torch.Size([46])
tensor([0.0737, 0.0110, 0.0921, 0.0988, 0.0923, 0.0682, 0.0085, 0.1100, 0.0774,
        0.0082, 0.1069, 0.0649, 0.1049, 0.0260, 0.1069, 0.0184, 0.0040, 0.0196,
        0.0066, 0.0882, 0.1411, 0.1127, 0.0125, 0.0037, 0.0025, 0.0687, 0.1418,
        0.0438, 0.1007, 0.1254, 0.0867, 0.0112, 0.0117, 0.0606, 0.1672, 0.0356,
        0.0133, 0.1161, 0.0171, 0.0091, 0.0203, 0.0831, 0.1057, 0.0135, 0.1393,
        0.0792])

If you do not include RemoveSimultaneous in the model, we also see 4 NPC devices with high "activations", just as we did before filtering in analysis-01-dynamics.
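
For reference, here is a minimal sketch (not run above) of how one could rebuild the model without RemoveSimultaneous to reproduce that comparison. Dropping that layer shifts the ShiftThenScale head from module '3' to module '2', so the default parameter keys need remapping.

In [ ]:
# Sketch: the same model without RemoveSimultaneous, for comparison.
model_1_no_rs = nn.Sequential(
    SoftThreshold(),
    Integrate(sum_dim=(1, 2, 3)),
    ShiftThenScale(1.0),
    nn.Flatten(start_dim=0)
)
# Remap the head's parameter keys from '3.*' to '2.*' before loading.
no_rs_defaults = {('2' + k[1:] if k.startswith('3.') else k): v
                  for k, v in model_1_defaults.items()}
model_1_no_rs.load_state_dict(no_rs_defaults)
with torch.no_grad():
    yhat_no_rs = model_1_no_rs(X_all)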

In [67]:
with torch.no_grad():
    plt.scatter(yhat_init_1.numpy(), y_all.numpy())
    plt.title("Ground truth vs. fraction of anomalous time (M1)")
    plt.xlabel("fraction of anomalous time")
    plt.ylabel("Plant / No plant (1 or 0)")
[Figure: scatter of ground truth (plant / no plant) vs. fraction of anomalous time (M1)]

Detections¶

To recreate what we did previously, we extract the initial layers of the model, which flag detections, and then we filter those detections as before.

Extract¶

In [30]:
model_1_trunc = nn.Sequential(
    SoftThreshold(),
    RemoveSimultaneous()
)
In [31]:
model_1.state_dict()
Out[31]:
OrderedDict([('0.L1.shift', tensor(-0.2500)),
             ('0.L1.scale', tensor(-100.)),
             ('0.L2.shift', tensor(-1.)),
             ('0.L2.scale', tensor(-100.)),
             ('0.L3.shift', tensor(-2.)),
             ('0.L3.scale', tensor(100.)),
             ('3.shift', tensor(0.)),
             ('3.scale', tensor(200.))])
In [32]:
for idx in ['0', '1']:
    model_1_trunc._modules[idx].load_state_dict(model_1._modules[idx].state_dict())
    # print(model_1_trunc._modules[idx].state_dict())
In [33]:
Z = model_1_trunc(X_all)

Filter¶

And now we do the thresholding as in analysis-01-dynamics.

In [34]:
temp_df_list = list()
for i, z in enumerate(Z):
    temp_df = transforms1.detection_trajectories(z > 0)
    temp_df['weight'] = 1.0
    temp_df['rt_serial'] = rt_deps_to_use['rt_serial'][i]
    temp_df['crop'] = rt_deps_to_use['crop'][i] 
    temp_df_list.append(temp_df)
In [35]:
temp_df.head()
Out[35]:
start end paddle electrode weight rt_serial crop
0 3905 3906 0 8 1.0 EG00240 Corn
1 3907 3913 0 8 1.0 EG00240 Corn
2 5127 5128 0 8 1.0 EG00240 Corn
3 3896 3902 0 10 1.0 EG00240 Corn
4 4054 4057 0 10 1.0 EG00240 Corn
In [36]:
traj_df_1 = pd.concat(temp_df_list)  
In [37]:
traj_df_1.shape
Out[37]:
(12849, 7)
In [38]:
traj_df_2 = transforms1.consolidate_trajectories(traj_df_1.query("electrode > 0"), grouping=['crop', 'rt_serial', 'paddle', 'electrode'], cush=6)
In [39]:
traj_df_2.shape
Out[39]:
(7576, 8)
In [40]:
traj_df_3a = transforms1.flag_overlapping_trajectories(traj_df_2.sort_values(['crop', 'rt_serial', 'start', 'end']), 
                                                       grouping=['crop', 'rt_serial', 'paddle'], cush=6)
# traj_df_3a.head(n=10)
In [41]:
traj_df_3 = traj_df_3a.groupby(['crop', 'rt_serial', 'paddle', 'group_num']).agg({
    'start': 'min',
    'end': 'max',
    'electrode': 'min'
    }).reset_index(drop=False)
traj_df_3.shape
Out[41]:
(4820, 7)
In [42]:
# traj_df_3.head()
In [43]:
# Remove trajectories that overlap in time and on adjacent paddles (we do not use modular arithmetic here, so this is only approximate).
traj_df_4a = transforms1.flag_overlapping_trajectories(traj_df_3.sort_values(['crop', 'rt_serial', 'start', 'end']), grouping=['crop', 'rt_serial'], cush=6)
traj_df_4a['num_in_group'] = traj_df_4a.groupby(['group_num'])['group_num'].transform(lambda x: len(x))
traj_df_4a['length'] = traj_df_4a['end'] - traj_df_4a['start']
traj_df_4a.shape
Out[43]:
(4820, 10)
In [44]:
traj_df_4a.head()
Out[44]:
crop rt_serial paddle group_num start end electrode overlap_prev num_in_group length
0 Corn EG00054 10 1 36 37 20 False 1 1
1 Corn EG00054 10 2 346 351 20 False 1 5
2 Corn EG00054 10 3 378 396 20 False 1 18
3 Corn EG00054 2 4 532 533 12 False 6 1
4 Corn EG00054 10 4 534 550 8 True 6 16

There are a lot of trajectories that occur on just 1 or 2 paddles. If you keep only the trajectories that have a single active paddle, you filter out roughly 2/3 of your detections! Whether we keep groups of size 1, or of sizes 1 and 2, we still get reasonable separation of NPC and Corn below. We could also handle this after the fact by modeling.
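
A quick check of that fraction from the columns computed above (a short sketch):

In [ ]:
# Sketch: fraction of trajectories dropped if we keep only single-paddle
# groups, and the alternative that also keeps groups of two.
print((traj_df_4a['num_in_group'] > 1).mean())
print((traj_df_4a['num_in_group'] > 2).mean())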

In [45]:
traj_df_4a['num_in_group'].value_counts()
Out[45]:
num_in_group
1     1837
2      976
3      621
4      444
5      300
6      276
7      168
8       88
9       45
11      33
10      20
12      12
Name: count, dtype: int64
In [46]:
pd.crosstab(traj_df_4a['num_in_group'], traj_df_4a['crop'])
Out[46]:
crop Corn NPC
num_in_group
1 1462 375
2 764 212
3 468 153
4 312 132
5 230 70
6 192 84
7 112 56
8 64 24
9 27 18
10 20 0
11 11 22
12 12 0
In [47]:
# Lots of overlap in this NPC, but it doesn't all align
# traj_df_4a.query("rt_serial == 'EG00103' & num_in_group == 6") 
In [48]:
# We remove 2/3 of trajectories by removing things that occur at the same time!!!
traj_df_4 = traj_df_4a.query("num_in_group <= 1")
traj_df_4.shape
Out[48]:
(1837, 10)
In [49]:
# traj_df_4.head()
In [50]:
traj_df_5 = traj_df_4.copy()
# traj_df_5 = traj_df_4.merge(crop_df, on='rt_serial')
traj_df_5.shape
Out[50]:
(1837, 10)
In [51]:
# traj_df_5['length'] = traj_df_5['end'] - traj_df_5['start']
traj_df_5['day'] = np.floor(traj_df_5['start'] / (12 * 24)).astype(int)
traj_df_5.head()
Out[51]:
crop rt_serial paddle group_num start end electrode overlap_prev num_in_group length day
0 Corn EG00054 10 1 36 37 20 False 1 1 0
1 Corn EG00054 10 2 346 351 20 False 1 5 1
2 Corn EG00054 10 3 378 396 20 False 1 18 1
9 Corn EG00054 11 5 641 668 8 False 1 27 2
15 Corn EG00054 0 8 972 998 7 False 1 26 3

Plot¶

Let's look at the same plots we created previously. They all look very similar.

In [52]:
gg.ggplot(traj_df_5) + gg.geom_histogram(gg.aes(x='length', fill='crop'), bins=20)
[Figure: histogram of trajectory length by crop]
Out[52]:
<Figure Size: (640 x 480)>
In [53]:
# We integrate here, which can be critical to getting a difference if you do not filter out short-length anomalies.
traj_df_6 = traj_df_5.query("length > 6").groupby(['crop', 'rt_serial']).agg({'length': 'sum', 'start': 'size'}).reset_index(drop=False)
traj_df_6['count'] = traj_df_6.pop('start')
traj_df_6.shape
Out[53]:
(46, 4)
In [54]:
# traj_df_6.head()

Looking at differences in cumulative time:

In [55]:
gg.ggplot(traj_df_6) + gg.geom_histogram(gg.aes(x='length', fill='crop'), bins=20) + \
  gg.ggtitle("Distribution of cumulative anomalous trajectory time by crop")
[Figure: distribution of cumulative anomalous trajectory time by crop]
Out[55]:
<Figure Size: (640 x 480)>

Looking at differences in counts:

In [56]:
gg.ggplot(traj_df_6) + gg.geom_histogram(gg.aes(x='count', fill='crop'), bins=20)  + \
  gg.ggtitle("Distribution of number of anomalous trajectories by crop")
[Figure: distribution of number of anomalous trajectories by crop]
Out[56]:
<Figure Size: (640 x 480)>
In [57]:
gg.ggplot(traj_df_5.query("length > 3")) + gg.geom_point(gg.aes('start', 'rt_serial', fill='crop')) + \
  gg.ggtitle("Detection times by device")
[Figure: detection times by device, colored by crop]
Out[57]:
<Figure Size: (640 x 480)>
In [58]:
traj_df_day = traj_df_5.groupby(['crop', 'rt_serial', 'day']).agg({'length': 'sum', 'start': 'size', 'electrode': 'mean'}).reset_index(drop=False)
traj_df_day['day'] = pd.Categorical(traj_df_day['day'], ordered=True)
# traj_df_day.head()
In [59]:
gg.ggplot(traj_df_day) + gg.geom_boxplot(gg.aes('day', 'start')) + gg.facet_wrap("~ crop") + \
  gg.ggtitle("Distribution of number of detections by day")
[Figure: distribution of number of detections by day, faceted by crop]
Out[59]:
<Figure Size: (640 x 480)>
In [60]:
gg.ggplot(traj_df_day) + gg.geom_boxplot(gg.aes('day', 'electrode')) + gg.facet_wrap("~ crop") + \
  gg.ggtitle("Distribution of mean electrode by day")
[Figure: distribution of mean electrode by day, faceted by crop]
Out[60]:
<Figure Size: (640 x 480)>

Conclusion (Part A)¶

Above, we showed that the soft-thresholded version of our work from analysis-01-dynamics does indeed produce very similar results. In other words, we can construct a model (a neural network) that recapitulates the ad hoc approach that worked well previously. In part B, we show that we can learn the optimal parameters of this model from data. (See inference-01-dynamics-bivariate-B.)
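
As a preview, here is a minimal sketch of what that inference could look like, assuming a standard setup (BCEWithLogitsLoss on the raw model output treated as a logit, Adam, full-batch updates via dyn_loader); the actual part-B notebook may make different choices.

In [ ]:
# Sketch (assumptions noted above): fit the shift/scale parameters of the
# model by gradient descent, treating the model output as a logit.
model_b = make_model_1()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model_b.parameters(), lr=1e-2)
for epoch in range(100):
    for X_batch, y_batch in dyn_loader:
        optimizer.zero_grad()
        loss = criterion(model_b(X_batch), y_batch)
        loss.backward()
        optimizer.step()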
