Introduction¶

We saw in Part A that one could use soft thresholding to create a model that recapitulates the results from analysis-01-dynamics. Here we show that we can learn the optimal parameters of such a model.

Setup¶

Import the required libraries and determine the device we will use for computations.

In [2]:
import os
import copy
import h5py
import pandas as pd
import numpy as np
import progressbar
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from matplotlib import pyplot as plt
import plotnine as gg
import importlib as il
In [3]:
%matplotlib inline
In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
Out[4]:
device(type='cuda', index=0)

Data to use¶

Get RootTracker deployment data.

In [5]:
DATA_LOC = "google"
In [6]:
DATA_DIR = "/usr/local/share/rtphenos-workspace"
if DATA_LOC == "google":
    from google.colab import drive
    drive.mount("/content/drive")
    DATA_DIR = "/content/drive/My Drive/rtphenos-workspace"
Mounted at /content/drive
In [7]:
rt_deps = pd.read_csv(os.path.join(DATA_DIR, "deployments", "rt_acc1_deps.csv"))
rt_deps.head()
Out[7]:
deployment_id roottracker_serial ram_serial x y bench crop genotype order pot_on_bench pot_on_row rep row deployment_name NPC
0 54 EG00043 D76080D25B7EA126 1.0 11.0 2 Soy PI416937 11 2 11 1 1 ACC Test 1 False
1 54 EG00044 D76080D25B3BD828 1.0 3.0 1 Wheat Comet 3 3 3 1 1 ACC Test 1 False
2 54 EG00045 D76080D25B7F7526 4.0 7.0 1 Cotton PI513384 115 7 7 1 4 ACC Test 1 False
3 54 EG00046 D76080D25B3C9628 2.0 11.0 2 Soy PI416937 47 2 11 1 2 ACC Test 1 False
4 54 EG00047 D76080D25B3A6C28 3.0 8.0 1 Cotton PI513384 80 8 8 1 3 ACC Test 1 False

We limit our attention to just the Corn and NPC devices, since we believe they will show the clearest separation. We also want to remove devices with odd data.

The first definition of bad_roottrackers corresponds to devices with odd data coming in on paddles 2 and 3. The second corresponds to that set plus the top 8 devices in terms of integrated mean velocity below a threshold, a set that is balanced between Corn and NPC. The filtering we did in our previous analysis was critical for eliminating these NPCs. Here we try to do something similar using the RemoveSimultaneous module defined below.

In [8]:
bad_roottrackers = ['EG00237', 'EG00226']
# bad_roottrackers = ['EG00237', 'EG00226', 'EG00228', 'EG00062', 'EG00157', 'EG00096', 'EG00160', 'EG00103', 'EG00159', 'EG00140']
In [9]:
rt_deps_to_use = rt_deps.query(f"(crop == 'Corn' | crop == 'NPC') & ~(roottracker_serial in {bad_roottrackers})").reset_index(drop=True).reset_index(drop=False)
rt_deps_to_use['rt_serial'] = rt_deps_to_use['roottracker_serial']
rt_deps_to_use.shape
Out[9]:
(46, 17)

Load data¶

We use the Dataset class from PyTorch to access our data.

In [10]:
class Dynamics(Dataset):

    def __init__(self, rt_deps, base_dir, start=2000, end=4000, slice=[3,5]):
        self.iter_idx = 0
        if isinstance(rt_deps, str):
            self.rt_deps = pd.read_csv(rt_deps)
        elif isinstance(rt_deps, pd.DataFrame):
            self.rt_deps = rt_deps.reset_index()
        else:
            raise ValueError("rt_deps must be a path to a CSV file or a DataFrame")
        self.base_dir = base_dir
        self.key = 'dyn'
        self.cache = [None for i in range(len(self.rt_deps))]
        self.start = start
        self.end = end
        self.slice = slice

    def __len__(self):
        return self.rt_deps.shape[0]

    def get_path(self, idx):
        rt_serial = self.rt_deps['roottracker_serial'][idx]
        full_path = os.path.join(self.base_dir, self.rt_deps['deployment_name'][idx], f"{rt_serial}.h5")
        if not os.path.exists(full_path):
            raise Exception(f"Path {full_path} does not exist")
        return full_path

    def __getitem__(self, idx):
        out = None
        if self.cache[idx] is None:
            with h5py.File(self.get_path(idx), 'r') as h5r:
                self.cache[idx] = torch.tensor(
                    np.array(
                        h5r[self.key][self.start:self.end,:,:,:]
                    )[:,:,:,self.slice]
                ).float()
        return self.cache[idx], torch.tensor(~self.rt_deps['NPC'][idx], dtype=torch.float32)

    def lookup(self, rt_serial):
        out = self.rt_deps.query(f"roottracker_serial == '{rt_serial}'")
        nout = len(out)
        if nout == 1:
            return out.index[0]
        elif nout > 1:
            raise Exception("Found multiple")
        else:
            raise Exception("Found none")

    def __iter__(self):
        self.iter_idx = 0
        return self

    def __next__(self):
        if self.iter_idx < len(self):
            out = self[self.iter_idx]
            self.iter_idx += 1
            return out
        raise StopIteration
In [11]:
# Should be balanced
rt_deps_to_use['crop'].value_counts()
Out[11]:
Corn    23
NPC     23
Name: crop, dtype: int64

Data Loaders and data¶

In [12]:
BATCH_SIZE=rt_deps_to_use.shape[0]
BATCH_SIZE
Out[12]:
46
In [13]:
# For our work below, it is critical to use all of the data
dyn_data = Dynamics(rt_deps_to_use, os.path.join(DATA_DIR, "dynamics"), start=0, end=8043, slice=[3,5])
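
Each item from the dataset is a (dynamics tensor, label) pair, with the label equal to 1.0 for a planted device and 0.0 for an NPC. A quick way to inspect one item (the exact dimensions, presumably time × paddles × electrodes × the two selected channels, depend on the HDF5 files):

x0, y0 = dyn_data[0]  # first device
x0.shape, y0          # 4-D dynamics tensor and its scalar label
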
In [14]:
# dyn_data[0][0][:,0,0]
In [15]:
dyn_loader = DataLoader(dyn_data, batch_size=BATCH_SIZE, shuffle=False)
In [16]:
y_all = torch.stack([d[1] for d in dyn_data])
In [17]:
X_all = torch.stack([d[0] for d in dyn_data])
In [18]:
rts = rt_deps_to_use['roottracker_serial']
In [19]:
# X_all[0,:,0,0,:]
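
As an aside, the extended bad_roottrackers list above was described as including the top 8 devices ranked by an integrated mean-velocity-below-a-threshold measure. With X_all in hand, a ranking in that spirit could be sketched as follows; the threshold and the exact aggregation are illustrative assumptions, not the values actually used.

# Hypothetical ranking in the spirit of the extended bad_roottrackers list above:
# for each device, the fraction of samples whose mean velocity (channel 0) falls
# below an illustrative threshold. This is a sketch, not the ranking actually used.
THRESH = 0.25  # illustrative value only
low_vel_score = (X_all[..., 0] < THRESH).float().mean(dim=(1, 2, 3))
pd.Series(low_vel_score.numpy(), index=rts).sort_values(ascending=False).head(8)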

Models¶

Here we build the pieces that make up the models.

Modules¶

The soft threshold used here is not just a sigmoid function. Letting $\sigma$ be the sigmoid function and $r$ the ReLU function, we use $$2\,r(\sigma(x) - 0.5).$$ This way we still have a "hard" threshold at 0, but the transition is soft beyond that.
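
As a quick numerical check of this function (torch is already imported above):

def soft_threshold(x):
    # 2 * relu(sigmoid(x) - 0.5): exactly 0 for x <= 0, rising smoothly toward 1 for x > 0
    return 2.0 * torch.relu(torch.sigmoid(x) - 0.5)

soft_threshold(torch.tensor([-2.0, 0.0, 1.0, 4.0]))
# tensor([0.0000, 0.0000, 0.4621, 0.9640])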

We use a slightly different model than in Part A. In particular, we use the MyAbs module below for the log variance, which creates a finite interval of high activity. Effectively, we soft threshold the mean velocity to be below a certain threshold and we soft threshold the log variance of the velocity to be within an interval.
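
Concretely, MyAbs computes $$m(x) = s\, e^{\lambda} \left(e^{w} - |x - c|\right),$$ where $c$ is the learned shift, $e^{w}$ the half-width, $e^{\lambda}$ the scale, and $s$ the fixed scale_mag. The pre-sigmoid value changes sign exactly at the interval endpoints $c \pm e^{w}$, so after the soft threshold the log variance is gated according to where it falls relative to the interval $[c - e^{w},\ c + e^{w}]$.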

In [20]:
class ShiftThenScale(nn.Module):

    def __init__(self, scale_mag=10.):
        super().__init__()
        self.shift = nn.Parameter(torch.randn(()))
        self.scale = nn.Parameter(torch.randn(()))
        self.scale_mag = scale_mag

    def forward(self, x):
        return self.scale_mag * self.scale * (x - self.shift)

    def __repr__(self):
        return f"ShiftThenScale({self.scale_mag})"
In [21]:
class MyAbs(nn.Module):

    def __init__(self, scale_mag=-1.):
        super().__init__()
        self.shift = nn.Parameter(torch.randn(()))
        self.log_half_width = nn.Parameter(torch.randn(()))
        self.log_scale = nn.Parameter(torch.randn(()))
        self.scale_mag = scale_mag

    def forward(self, x):
        return self.scale_mag * torch.exp(self.log_scale) * (torch.exp(self.log_half_width) -  torch.abs(x - self.shift))
In [22]:
class SoftThreshold2(nn.Module):

    def __init__(self):
        super().__init__()
        self.L = ShiftThenScale()
        self.A = MyAbs()
        self.S = nn.Sigmoid()
        self.R = nn.ReLU()

    def forward(self, x):
        # x has shape (..., 2): channel 0 is the mean velocity, channel 1 the velocity variance
        last_dim = len(x.size()) - 1
        shp = list(x.shape)
        shp[last_dim] = 2
        # y = torch.zeros_like(x)
        y = torch.zeros(shp, dtype=x.dtype, layout=x.layout, device=x.device)
        y[:,:,:,:,0] = self.L(x[:,:,:,:,0])            # affine transform of the mean velocity
        x1log = torch.log(x[:,:,:,:,1] + 1e-10)        # log variance (small epsilon avoids log(0))
        y[:,:,:,:,1] = self.A(x1log)                   # MyAbs gate on the log variance
        y = self.S(y)
        y = 2.0 * (self.R(y - 0.5))                    # soft threshold: 2 * relu(sigmoid(.) - 0.5)
        z = torch.prod(y, dim=last_dim, keepdim=True)  # combine the two gates multiplicatively
        z = z.squeeze(last_dim)
        return z
In [23]:
class SoftThreshold3(nn.Module):

    def __init__(self):
        super().__init__()
        self.L = nn.Linear(1,1)
        self.A = MyAbs()
        self.S = nn.Sigmoid()
        self.R = nn.ReLU()

    def forward(self, x):
        last_dim = len(x.size()) - 1
        shp = list(x.shape)
        shp[last_dim] = 2
        # y = torch.zeros_like(x)
        y = torch.zeros(shp, dtype=x.dtype, layout=x.layout, device=x.device)
        y[:,:,:,:,0] = self.L(x[:,:,:,:,[0]])[:,:,:,:,0]
        x1log = torch.log(x[:,:,:,:,1] + 1e-10)
        y[:,:,:,:,1] = self.A(x1log)
        y = self.S(y)
        y = 2.0 * (self.R(y - 0.5))
        z = torch.prod(y, dim=last_dim, keepdim=True)
        z = z.squeeze(last_dim)
        return z
In [24]:
class Integrate(nn.Module):
    def __init__(self, sum_dim):
        super().__init__()
        self.sum_dim = sum_dim

    def forward(self, x):
        n = len(x.size())
        x = torch.mean(x, dim=self.sum_dim, keepdim=True)
        return x

    def __repr__(self):
        return (f"Integrate(sum_dim={self.sum_dim}")
In [25]:
class Standardize(nn.Module):

    def __init__(self, return_both=True, use_sqrt=False):
        super().__init__()
        self.return_both = return_both
        self.use_sqrt = use_sqrt

    def forward(self, x):
        xs = x.squeeze()
        if self.use_sqrt:
            xs = torch.sqrt(xs + 1e-10)
        # Median seems like a more intuitive approach, but mean seems better in terms of learning
        # m = torch.mean(xs) # but this makes things more sensitive to all data
        # s = torch.std(xs) # you don't gain anything by squashing things at 50% here
        m = torch.median(xs) # ensures balance, and could be used with known imbalance
        s = torch.sqrt(torch.mean(torch.square(xs - m)))
        y = (xs - m) / s
        if self.return_both:
            return torch.stack([y, xs])
        else:
            return y
In [26]:
class RemoveSimultaneous(nn.Module):
    def __init__(self):
        super().__init__()
        self.R = nn.ReLU()

    def forward(self, x):
        shp = x.size()
        npads = shp[2]
        y = x.new(shp)
        for i in range(npads):
            others = [True] * npads
            others[i] = False
            val1, _ = torch.max(x[:,:,others,:], dim=2) # Max along other paddles
            val2, _ = torch.max(val1, dim=2, keepdim=True) # Max along electrodes
            # y[:,:,i,:] = x[:,:,i,:] - val1
            y[:,:,i,:] = x[:,:,i,:] - val2
        y = self.R(y)
        return y
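
To make the behavior concrete, here is a tiny worked example with made-up values (one batch element, one time step, three paddles, two electrodes). A detection that appears on several paddles at the same time step is treated as simultaneous noise: each paddle keeps only the amount by which it exceeds the strongest value seen on any other paddle.

rs = RemoveSimultaneous()
x = torch.tensor([[[[1.0, 0.2],     # paddle 0
                    [0.9, 0.1],     # paddle 1
                    [0.0, 0.0]]]])  # paddle 2
rs(x)
# tensor([[[[0.1000, 0.0000],
#           [0.0000, 0.0000],
#           [0.0000, 0.0000]]]])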

Model definitions (wrapped as function calls)¶

In [27]:
def make_model_2():
    out = nn.Sequential(
        SoftThreshold2(),
        RemoveSimultaneous(), # Takes a long time without GPU
        Integrate(sum_dim = (1, 2, 3)),
        nn.Flatten(start_dim=1),
        nn.BatchNorm1d(1, momentum = 0.0),
        nn.Flatten(start_dim=0)
    )
    return out
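
For reference, a rough trace of the shapes through this pipeline; the sketch below only confirms that the composition runs end to end, and the intermediate shapes follow from the module definitions above.

# SoftThreshold2:      (batch, time, paddles, electrodes, 2) -> (batch, time, paddles, electrodes)
# RemoveSimultaneous:  same shape, simultaneous cross-paddle detections zeroed out
# Integrate((1,2,3)):  -> (batch, 1, 1, 1), the mean detection level per device
# Flatten / BatchNorm1d / Flatten: -> (batch,), standardized scores used as log odds
with torch.no_grad():
    make_model_2().eval()(X_all[:2]).shape  # torch.Size([2])
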
In [28]:
def myopt(device, optimizer, model, criterion, data_loader, iter=50):
    widgets = [
        progressbar.Percentage(), " ",
        progressbar.Bar(), " ",
        progressbar.ETA(), ", ",
        progressbar.Variable('loss', precision=6),
    ]
    with progressbar.ProgressBar(max_value=iter, widgets=widgets) as bar:
        for epoch in range(iter):  # loop over the dataset multiple times
            running_loss = 0.0
            for i, data in enumerate(data_loader, 0): # There should just be 1 batch
                inputs_cpu, labels_cpu = data
                labels_cpu = 1*labels_cpu
                inputs = inputs_cpu.to(device)
                labels = labels_cpu.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                outputs_cpu = outputs.cpu()
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
            bar.update(epoch, loss=loss.item())
    return outputs_cpu, labels_cpu

Fitting a model¶

Model¶

Here we initialize the model with values that we know work, i.e. those found in analysis-01-dynamics and Part A. We then allow the optimization to refine those parameters.

As before, we print and plot the results of the initialization, optimize, and then print and plot the results afterwards.

In [29]:
KNOWN_START = True
In [30]:
criterion = nn.BCEWithLogitsLoss()
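
As a reminder, BCEWithLogitsLoss combines a sigmoid with binary cross-entropy, which is why the raw model outputs are treated as log odds throughout: for a logit $z$ and label $y \in \{0, 1\}$, $$\ell(z, y) = -\big[y \log \sigma(z) + (1 - y) \log(1 - \sigma(z))\big],$$ averaged over the batch.
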
In [31]:
# torch.manual_seed(1515) # low mom 2
# torch.manual_seed(1530) # high mom 2
In [32]:
model_2 = make_model_2()
In [33]:
model_2_init = model_2.state_dict()
model_2_init
Out[33]:
OrderedDict([('0.L.shift', tensor(0.7376)),
             ('0.L.scale', tensor(1.4841)),
             ('0.A.shift', tensor(1.8207)),
             ('0.A.log_half_width', tensor(-0.2463)),
             ('0.A.log_scale', tensor(0.8489)),
             ('4.weight', tensor([1.])),
             ('4.bias', tensor([0.])),
             ('4.running_mean', tensor([0.])),
             ('4.running_var', tensor([1.])),
             ('4.num_batches_tracked', tensor(0))])
In [34]:
scale1 = 10.
scale2 = 20.
if KNOWN_START:
  model_2_defaults = copy.deepcopy(model_2_init)
  model_2_changes = {
       '0.L.scale': torch.tensor(-scale1),
        '0.L.shift': torch.tensor(-0.25),
        '0.A.log_scale': torch.tensor(np.log(scale1)),
        '0.A.shift': torch.tensor(-1.5),
        '0.A.log_half_width': torch.tensor(np.log(0.5)),
    }
  for k, v in model_2_changes.items():
      model_2_defaults[k] = v
else:
    model_2_defaults = copy.deepcopy(model_2_init)
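
With the KNOWN_START values, the pre-sigmoid input of the mean-velocity gate is $$10 \cdot (-10)\,(x - (-0.25)) = -100\,(x + 0.25),$$ which is positive, and hence survives the soft threshold, exactly when the mean velocity satisfies $x < -0.25$; the log-variance gate starts out centered at $-1.5$ with half-width $0.5$ and scale $10$.
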
In [35]:
model_2.load_state_dict(model_2_defaults)
Out[35]:
<All keys matched successfully>
In [36]:
model_2.state_dict()
Out[36]:
OrderedDict([('0.L.shift', tensor(-0.2500)),
             ('0.L.scale', tensor(-10.)),
             ('0.A.shift', tensor(-1.5000)),
             ('0.A.log_half_width', tensor(-0.6931)),
             ('0.A.log_scale', tensor(2.3026)),
             ('4.weight', tensor([1.])),
             ('4.bias', tensor([0.])),
             ('4.running_mean', tensor([0.])),
             ('4.running_var', tensor([1.])),
             ('4.num_batches_tracked', tensor(0))])
In [37]:
with torch.no_grad():
    yhat_init_2 = model_2(X_all)
    # print(yhat_init_2.size())
    # print(yhat_init_2)
In [53]:
with torch.no_grad():
    plt.scatter(yhat_init_2.numpy(), y_all.numpy())
    plt.title("Ground truth versus predicted log odds")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
[Figure: scatter plot of predicted log odds (x-axis) versus ground truth plant / no plant (y-axis)]
In [39]:
with torch.no_grad():
    print(criterion(yhat_init_2, y_all))
tensor(0.6606)
In [40]:
model_2.to(device)
Out[40]:
Sequential(
  (0): SoftThreshold2(
    (L): ShiftThenScale(10.0)
    (A): MyAbs()
    (S): Sigmoid()
    (R): ReLU()
  )
  (1): RemoveSimultaneous(
    (R): ReLU()
  )
  (2): Integrate(sum_dim=(1, 2, 3))
  (3): Flatten(start_dim=1, end_dim=-1)
  (4): BatchNorm1d(1, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True)
  (5): Flatten(start_dim=0, end_dim=-1)
)
In [41]:
optimizer_2 = optim.Adam(model_2.parameters(), lr=1e-1, weight_decay=0.0)
In [42]:
outputs, labels = myopt(device, optimizer_2, model_2, criterion, dyn_loader, iter=80)
100% |#########################################| Time:  0:13:06, loss: 0.370897
In [43]:
# with torch.no_grad():
#     print(outputs, labels)
In [52]:
with torch.no_grad():
    plt.scatter(outputs.numpy(), labels.numpy())
    plt.title("Ground truth versus predicted log odds")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
[Figure: scatter plot of predicted log odds (x-axis) versus ground truth plant / no plant (y-axis), after optimization]
In [45]:
[(name, theta, theta.grad) for name, theta in model_2.named_parameters()]
Out[45]:
[('0.L.shift',
  Parameter containing:
  tensor(-0.2964, device='cuda:0', requires_grad=True),
  tensor(0.0182, device='cuda:0')),
 ('0.L.scale',
  Parameter containing:
  tensor(-12.0017, device='cuda:0', requires_grad=True),
  tensor(2.5145e-05, device='cuda:0')),
 ('0.A.shift',
  Parameter containing:
  tensor(-3.0548, device='cuda:0', requires_grad=True),
  tensor(4.5884e-07, device='cuda:0')),
 ('0.A.log_half_width',
  Parameter containing:
  tensor(-2.0869, device='cuda:0', requires_grad=True),
  tensor(5.6932e-08, device='cuda:0')),
 ('0.A.log_scale',
  Parameter containing:
  tensor(3.7463, device='cuda:0', requires_grad=True),
  tensor(-8.5860e-08, device='cuda:0')),
 ('4.weight',
  Parameter containing:
  tensor([8.5806], device='cuda:0', requires_grad=True),
  tensor([-0.0274], device='cuda:0')),
 ('4.bias',
  Parameter containing:
  tensor([-0.0011], device='cuda:0', requires_grad=True),
  tensor([0.0008], device='cuda:0'))]
In [46]:
# We end up with a pretty narrow band for the standard deviation.
with torch.no_grad():
  cent = model_2._modules['0'].A.shift.cpu().numpy()
  hw = torch.exp(model_2._modules['0'].A.log_half_width).cpu().numpy()
  print("Log var center: {:.2f}, log var half width: {:.2f}, sd low, high: {:.2f}, {:.2f}".format(
      cent, hw, np.exp(0.5*(cent - hw)), np.exp(0.5*(cent + hw))
  ))
Log var center: -3.05, log var half width: 0.12, sd low, high: 0.20, 0.23
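
For reference, the conversion from the learned log-variance band to a standard-deviation band uses $\sigma = e^{v/2}$ for log variance $v$, so the interval $[c - h,\ c + h]$ maps to $$\big[e^{(c - h)/2},\ e^{(c + h)/2}\big] \approx [0.20,\ 0.23]$$ with $c \approx -3.05$ and $h \approx 0.12$.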

Enforcing sparsity in NPC¶

Another option for improving model performance is to build into the loss function a penalty for excessive detections in the NPCs. We don't just want to separate the two groups; we want to eliminate detections from the NPCs.

One challenge in doing this is figuring out the correct trade-off between the original criterion and the "sparsity" penalty we are trying to enforce. Presumably, one could use a reverse-lasso-type approach: start with no penalty and then gradually increase it.

Here the hypothetical output from the model includes the log-odds values as well as the values after the integration step. While we have implemented this approach, we forgo it here, since the results above suffice; we mention it for the sake of completeness.

In [47]:
# Penalize NPC - finding the right balance is key here
def criterion_with_sparsity(myout, labels, pen=1):
    return criterion(myout[0,:], labels) + 0.69 * pen * torch.mean((1 - labels) * myout[1,:])

# def criterion3(myout, labels):
#     return torch.mean((1 - labels) * myout[1,:]) + torch.mean(labels / myout[1,:])
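
For completeness, a model whose output matches what criterion_with_sparsity expects could be assembled from the modules above, with Standardize(return_both=True) supplying both rows: row 0 is the standardized score used as log odds, and row 1 holds the raw integrated detections that the penalty drives toward zero for the NPCs. This is only a sketch of one such variant, not the implementation referred to above.

def make_model_sparse():
    return nn.Sequential(
        SoftThreshold2(),
        RemoveSimultaneous(),
        Integrate(sum_dim=(1, 2, 3)),
        Standardize(return_both=True),  # output has shape (2, batch): [standardized, raw]
    )

# e.g. loss = criterion_with_sparsity(make_model_sparse()(X_all), y_all, pen=1)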

Conclusion¶

Herein we have shown that we can formalize the ad hoc method from analysis-01-dynamics as a neural network model and learn the optimal set of parameters for this model given data.

The data here consist of only 46 cases, which is sufficient for learning the parameters of our model. However, our approach suggests that with a larger dataset one could consider larger models, e.g. a deeper neural network, to capture more complex patterns distinguishing plants from controls.
