Introduction¶
In Part A we saw that soft thresholding can be used to build a model that recapitulates the results from analysis-01-dynamics. Here we want to show that the optimal parameters of such a model can be learned from data.
Setup¶
Import the required packages and determine the device we will use for computations.
import os
import copy
import h5py
import pandas as pd
import numpy as np
import progressbar
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from matplotlib import pyplot as plt
import plotnine as gg
import importlib as il
%matplotlib inline
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
device(type='cuda', index=0)
Data to use¶
Get RootTracker deployment data.
DATA_LOC = "google"
DATA_DIR = "/usr/local/share/rtphenos-workspace"
if DATA_LOC == "google":
    from google.colab import drive
    drive.mount("/content/drive")
    DATA_DIR = "/content/drive/My Drive/rtphenos-workspace"
Mounted at /content/drive
rt_deps = pd.read_csv(os.path.join(DATA_DIR, "deployments", "rt_acc1_deps.csv"))
rt_deps.head()
 | deployment_id | roottracker_serial | ram_serial | x | y | bench | crop | genotype | order | pot_on_bench | pot_on_row | rep | row | deployment_name | NPC |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 54 | EG00043 | D76080D25B7EA126 | 1.0 | 11.0 | 2 | Soy | PI416937 | 11 | 2 | 11 | 1 | 1 | ACC Test 1 | False |
1 | 54 | EG00044 | D76080D25B3BD828 | 1.0 | 3.0 | 1 | Wheat | Comet | 3 | 3 | 3 | 1 | 1 | ACC Test 1 | False |
2 | 54 | EG00045 | D76080D25B7F7526 | 4.0 | 7.0 | 1 | Cotton | PI513384 | 115 | 7 | 7 | 1 | 4 | ACC Test 1 | False |
3 | 54 | EG00046 | D76080D25B3C9628 | 2.0 | 11.0 | 2 | Soy | PI416937 | 47 | 2 | 11 | 1 | 2 | ACC Test 1 | False |
4 | 54 | EG00047 | D76080D25B3A6C28 | 3.0 | 8.0 | 1 | Cotton | PI513384 | 80 | 8 | 8 | 1 | 3 | ACC Test 1 | False |
We limit our attention to corn and NPC devices, since we believe these will show the clearest separation. We also remove devices with odd data.
The first definition of bad_roottrackers corresponds to the devices with odd data coming in on paddles 2 and 3. The second (commented out) corresponds to those devices plus the top 8 devices in terms of integrated mean-velocity-below-a-threshold, which is balanced between corn and NPC. The filtering we do in our previous analysis is critical for eliminating these NPCs. Here we try to do something similar using the RemoveSimultaneous module defined below.
bad_roottrackers = ['EG00237', 'EG00226']
# bad_roottrackers = ['EG00237', 'EG00226', 'EG00228', 'EG00062', 'EG00157', 'EG00096', 'EG00160', 'EG00103', 'EG00159', 'EG00140']
rt_deps_to_use = rt_deps.query(f"(crop == 'Corn' | crop == 'NPC') & ~(roottracker_serial in {bad_roottrackers})").reset_index(drop=True).reset_index(drop=False)
rt_deps_to_use['rt_serial'] = rt_deps_to_use['roottracker_serial']
rt_deps_to_use.shape
(46, 17)
Load data¶
We use the Dataset class from PyTorch to access our data.
class Dynamics(Dataset):
    def __init__(self, rt_deps, base_dir, start=2000, end=4000, slice=[3,5]):
        self.iter_idx = 0
        if isinstance(rt_deps, str):
            # rt_deps is a path to a CSV file of deployments
            self.rt_deps = pd.read_csv(rt_deps)
        elif isinstance(rt_deps, pd.DataFrame):
            self.rt_deps = rt_deps.reset_index()
        else:
            raise ValueError("rt_deps must be a path to a csv file or a Dataframe")
        self.base_dir = base_dir
        self.key = 'dyn'
        # One cache slot per device; filled lazily in __getitem__
        self.cache = [None for i in range(len(self.rt_deps))]
        self.start = start
        self.end = end
        self.slice = slice

    def __len__(self):
        return self.rt_deps.shape[0]

    def get_path(self, idx):
        rt_serial = self.rt_deps['roottracker_serial'][idx]
        full_path = os.path.join(self.base_dir, self.rt_deps['deployment_name'][idx], f"{rt_serial}.h5")
        if not os.path.exists(full_path):
            raise Exception(f"Path {full_path} does not exist")
        return full_path

    def __getitem__(self, idx):
        out = None
        if self.cache[idx] is None:
            with h5py.File(self.get_path(idx), 'r') as h5r:
                self.cache[idx] = torch.tensor(
                    np.array(
                        h5r[self.key][self.start:self.end,:,:,:]
                    )[:,:,:,self.slice]
                ).float()
        # Label is 1.0 for a plant and 0.0 for an NPC device
        return self.cache[idx], torch.tensor(~self.rt_deps['NPC'][idx], dtype=torch.float32)

    def lookup(self, rt_serial):
        out = self.rt_deps.query(f"roottracker_serial == '{rt_serial}'")
        nout = len(out)
        if nout == 1:
            return out.index[0]
        elif nout > 1:
            raise Exception("Found multiple")
        else:
            raise Exception("Found none")

    def __iter__(self):
        self.iter_idx = 0
        return self

    def __next__(self):
        if self.iter_idx < len(self):
            out = self[self.iter_idx]
            self.iter_idx += 1
            return out
        raise StopIteration
# Should be balanced
rt_deps_to_use['crop'].value_counts()
Corn    23
NPC     23
Name: crop, dtype: int64
Data Loaders and data¶
BATCH_SIZE=rt_deps_to_use.shape[0]
BATCH_SIZE
46
# For our work below, it is critical to use all of the data
dyn_data = Dynamics(rt_deps_to_use, os.path.join(DATA_DIR, "dynamics"), start=0, end=8043, slice=[3,5])
# dyn_data[0][0][:,0,0]
dyn_loader = DataLoader(dyn_data, batch_size=BATCH_SIZE, shuffle=False)
y_all = torch.stack([d[1] for d in dyn_data])
X_all = torch.stack([d[0] for d in dyn_data])
rts = rt_deps_to_use['roottracker_serial']
# X_all[0,:,0,0,:]
Models¶
Here we build the pieces that make up the models.
Modules¶
The soft threshold used here is not just a sigmoid function. Letting $\sigma$ be a sigmoid function and $r$ be the ReLU function, we use $$2\, r\big(\sigma(x) - 0.5\big).$$ This keeps a "hard" threshold at 0 (the output is exactly 0 for $x \le 0$), while for $x > 0$ the output rises smoothly toward 1.
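As a quick numerical illustration of this function, the following is a minimal sketch in plain Python. The notebook applies the same formula elementwise with nn.Sigmoid and nn.ReLU; the name soft_threshold is hypothetical and used only for this illustration.

```python
import math

def soft_threshold(x):
    """Return 2 * relu(sigmoid(x) - 0.5) for a scalar x."""
    sig = 1.0 / (1.0 + math.exp(-x))
    return 2.0 * max(sig - 0.5, 0.0)

# Exactly zero at and below 0, then rising smoothly toward 1
for x in (-2.0, 0.0, 0.5, 2.0, 10.0):
    print(f"x = {x:5.1f} -> {soft_threshold(x):.4f}")
```

For $x > 0$ the expression equals $\tanh(x/2)$, since $2\sigma(x) - 1 = \tanh(x/2)$, so the function is a one-sided tanh.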
We use a slightly different model than in Part A. In particular, we use the MyAbs module below for the log variance, which creates a finite interval of high activity. Effectively, we soft threshold the mean velocity to be below a certain threshold and we soft threshold the log variance of the velocity to be within an interval.
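To make the interval construction concrete, here is a minimal plain-Python sketch of the pre-activation that MyAbs computes, scale_mag * exp(log_scale) * (exp(log_half_width) - |x - shift|), followed by the soft threshold. The function names are hypothetical and the parameter values are illustrative (they echo the known starting values used later in this notebook). The sign of scale_mag decides which side of the interval is active: with a positive value the activation is nonzero inside [shift - half_width, shift + half_width], whereas with a negative value, such as the default scale_mag=-1. in the MyAbs definition, it is nonzero outside that interval.

```python
import math

def soft_threshold(x):
    sig = 1.0 / (1.0 + math.exp(-x))
    return 2.0 * max(sig - 0.5, 0.0)

def interval_activation(x, shift, log_half_width, log_scale, scale_mag=1.0):
    """Soft-thresholded MyAbs-style pre-activation for a scalar x."""
    half_width = math.exp(log_half_width)
    pre = scale_mag * math.exp(log_scale) * (half_width - abs(x - shift))
    return soft_threshold(pre)

# Illustrative parameters: interval centered at -1.5 with half width 0.5
params = dict(shift=-1.5, log_half_width=math.log(0.5), log_scale=math.log(10.0))
for x in (-2.5, -1.75, -1.5, -1.25, -0.5):
    inside = interval_activation(x, scale_mag=1.0, **params)
    flipped = interval_activation(x, scale_mag=-1.0, **params)
    print(f"log var = {x:5.2f}  scale_mag=+1: {inside:.3f}  scale_mag=-1: {flipped:.3f}")
```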
class ShiftThenScale(nn.Module):
    def __init__(self, scale_mag=10.):
        super().__init__()
        self.shift = nn.Parameter(torch.randn(()))
        self.scale = nn.Parameter(torch.randn(()))
        self.scale_mag = scale_mag

    def forward(self, x):
        return self.scale_mag * self.scale * (x - self.shift)

    def __repr__(self):
        return f"ShiftThenScale({self.scale_mag})"

class MyAbs(nn.Module):
    def __init__(self, scale_mag=-1.):
        super().__init__()
        self.shift = nn.Parameter(torch.randn(()))
        self.log_half_width = nn.Parameter(torch.randn(()))
        self.log_scale = nn.Parameter(torch.randn(()))
        self.scale_mag = scale_mag

    def forward(self, x):
        return self.scale_mag * torch.exp(self.log_scale) * (torch.exp(self.log_half_width) - torch.abs(x - self.shift))
class SoftThreshold2(nn.Module):
    def __init__(self):
        super().__init__()
        self.L = ShiftThenScale()
        self.A = MyAbs()
        self.S = nn.Sigmoid()
        self.R = nn.ReLU()

    def forward(self, x):
        last_dim = len(x.size()) - 1
        shp = list(x.shape)
        shp[last_dim] = 2
        # y = torch.zeros_like(x)
        y = torch.zeros(shp, dtype=x.dtype, layout=x.layout, device=x.device)
        y[:,:,:,:,0] = self.L(x[:,:,:,:,0])
        x1log = torch.log(x[:,:,:,:,1] + 1e-10)
        y[:,:,:,:,1] = self.A(x1log)
        y = self.S(y)
        y = 2.0 * (self.R(y - 0.5))
        z = torch.prod(y, dim=last_dim, keepdim=True)
        z = z.squeeze(last_dim)
        return z

class SoftThreshold3(nn.Module):
    def __init__(self):
        super().__init__()
        self.L = nn.Linear(1,1)
        self.A = MyAbs()
        self.S = nn.Sigmoid()
        self.R = nn.ReLU()

    def forward(self, x):
        last_dim = len(x.size()) - 1
        shp = list(x.shape)
        shp[last_dim] = 2
        # y = torch.zeros_like(x)
        y = torch.zeros(shp, dtype=x.dtype, layout=x.layout, device=x.device)
        y[:,:,:,:,0] = self.L(x[:,:,:,:,[0]])[:,:,:,:,0]
        x1log = torch.log(x[:,:,:,:,1] + 1e-10)
        y[:,:,:,:,1] = self.A(x1log)
        y = self.S(y)
        y = 2.0 * (self.R(y - 0.5))
        z = torch.prod(y, dim=last_dim, keepdim=True)
        z = z.squeeze(last_dim)
        return z
class Integrate(nn.Module):
    def __init__(self, sum_dim):
        super().__init__()
        self.sum_dim = sum_dim

    def forward(self, x):
        n = len(x.size())
        x = torch.mean(x, dim=self.sum_dim, keepdim=True)
        return x

    def __repr__(self):
        return (f"Integrate(sum_dim={self.sum_dim}")

class Standardize(nn.Module):
    def __init__(self, return_both=True, use_sqrt=False):
        super().__init__()
        self.return_both = return_both
        self.use_sqrt = use_sqrt

    def forward(self, x):
        xs = x.squeeze()
        if self.use_sqrt:
            xs = torch.sqrt(xs + 1e-10)
        # Median seems like a more intuitive approach, but mean seems better in terms of learning
        # m = torch.mean(xs) # but this makes things more sensitive to all data
        # s = torch.std(xs) # you don't gain anything by squashing things at 50% here
        m = torch.median(xs) # ensures balance, and could be used with known imbalance
        s = torch.sqrt(torch.mean(torch.square(xs - m)))
        y = (xs - m) / s
        if self.return_both:
            return torch.stack([y, xs])
        else:
            return y
class RemoveSimultaneous(nn.Module):
    def __init__(self):
        super().__init__()
        self.R = nn.ReLU()

    def forward(self, x):
        shp = x.size()
        npads = shp[2]
        y = x.new(shp)
        for i in range(npads):
            others = [True] * npads
            others[i] = False
            val1, _ = torch.max(x[:,:,others,:], dim=2) # Max along other paddles
            val2, _ = torch.max(val1, dim=2, keepdim=True) # Max along electrodes
            # y[:,:,i,:] = x[:,:,i,:] - val1
            y[:,:,i,:] = x[:,:,i,:] - val2
        y = self.R(y)
        return y
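The effect of RemoveSimultaneous can be seen on a single time step. The sketch below is a plain-Python restatement of the loop above for one paddles-by-electrodes frame (the function name and the toy values are hypothetical). From each paddle's activations it subtracts the largest activation found on any other paddle and then clips at zero, so activity that appears on several paddles at once is suppressed while activity confined to one paddle survives.

```python
def remove_simultaneous_frame(frame):
    """frame[p][e] is the activation on paddle p, electrode e at one time step."""
    out = []
    for i, row in enumerate(frame):
        # Largest activation over all electrodes of every *other* paddle
        others_max = max(v for j, other in enumerate(frame) if j != i for v in other)
        # Subtract it and clip at zero (the ReLU step)
        out.append([max(v - others_max, 0.0) for v in row])
    return out

# Activity concentrated on paddle 0 survives (reduced by the next-largest value)
isolated = [[1.0, 0.0], [0.25, 0.0], [0.0, 0.0]]
# Strong activity on every paddle at once is removed entirely
simultaneous = [[1.0, 0.0], [1.0, 0.5], [0.75, 1.0]]

print(remove_simultaneous_frame(isolated))
print(remove_simultaneous_frame(simultaneous))
```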
Model definitions (wrapped as function calls)¶
def make_model_2():
    out = nn.Sequential(
        SoftThreshold2(),
        RemoveSimultaneous(), # Takes a long time without GPU
        Integrate(sum_dim = (1, 2, 3)),
        nn.Flatten(start_dim=1),
        nn.BatchNorm1d(1, momentum = 0.0),
        nn.Flatten(start_dim=0)
    )
    return out
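The last stages of make_model_2 reduce each device to a single integrated score and pass the batch of scores through nn.BatchNorm1d(1). In training mode that layer standardizes the scores with the current batch's mean and (biased) variance and then applies its learnable weight and bias, so the final output acts as a logit that is linear in the standardized score. The following is a minimal plain-Python sketch of that step under those assumptions; the function name and the example scores are hypothetical.

```python
import math

def batchnorm_logits(scores, weight=1.0, bias=0.0, eps=1e-5):
    """Training-mode batch norm of a batch of scalar scores, then an affine map."""
    n = len(scores)
    mean = sum(scores) / n
    var = sum((s - mean) ** 2 for s in scores) / n  # biased (population) variance
    return [weight * (s - mean) / math.sqrt(var + eps) + bias for s in scores]

scores = [0.02, 0.03, 0.10, 0.15]  # illustrative integrated activations
logits = batchnorm_logits(scores)
probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
print([round(z, 3) for z in logits])
print([round(p, 3) for p in probs])
```

Because the statistics come from the batch itself, the logits depend on which devices are in the batch, which is presumably one reason the notebook uses a single batch containing all devices.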
def myopt(device, optimizer, model, criterion, data_loader, iter=50):
    widgets = [
        progressbar.Percentage(), " ",
        progressbar.Bar(), " ",
        progressbar.ETA(), ", ",
        progressbar.Variable('loss', precision=6),
    ]
    with progressbar.ProgressBar(max_value=iter, widgets=widgets) as bar:
        for epoch in range(iter): # loop over the dataset multiple times
            running_loss = 0.0
            # Iterate over the loader passed in as an argument
            for i, data in enumerate(data_loader, 0): # There should just be 1 batch
                inputs_cpu, labels_cpu = data
                labels_cpu = 1*labels_cpu
                inputs = inputs_cpu.to(device)
                labels = labels_cpu.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                outputs_cpu = outputs.cpu()
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
            bar.update(epoch, loss=loss.item())
    return outputs_cpu, labels_cpu
Fitting a model¶
Model¶
Here we initialize the model with values that we know work, i.e. those found in analysis-01-dynamics and Part A. We then allow the optimization to refine those parameters.
As before, we print and plot the results of the initialization, optimize, and then print and plot the results afterwards.
KNOWN_START = True
criterion = nn.BCEWithLogitsLoss()
# torch.manual_seed(1515) # low mom 2
# torch.manual_seed(1530) # high mom 2
model_2 = make_model_2()
model_2_init = model_2.state_dict()
model_2_init
OrderedDict([('0.L.shift', tensor(0.7376)), ('0.L.scale', tensor(1.4841)), ('0.A.shift', tensor(1.8207)), ('0.A.log_half_width', tensor(-0.2463)), ('0.A.log_scale', tensor(0.8489)), ('4.weight', tensor([1.])), ('4.bias', tensor([0.])), ('4.running_mean', tensor([0.])), ('4.running_var', tensor([1.])), ('4.num_batches_tracked', tensor(0))])
scale1 = 10.
scale2 = 20.
if KNOWN_START:
    model_2_defaults = copy.deepcopy(model_2_init)
    model_2_changes = {
        '0.L.scale': torch.tensor(-scale1),
        '0.L.shift': torch.tensor(-0.25),
        '0.A.log_scale': torch.tensor(np.log(scale1)),
        '0.A.shift': torch.tensor(-1.5),
        '0.A.log_half_width': torch.tensor(np.log(0.5)),
    }
    for k, v in model_2_changes.items():
        model_2_defaults[k] = v
else:
    model_2_defaults = copy.deepcopy(model_2_init)
model_2.load_state_dict(model_2_defaults)
<All keys matched successfully>
model_2.state_dict()
OrderedDict([('0.L.shift', tensor(-0.2500)), ('0.L.scale', tensor(-10.)), ('0.A.shift', tensor(-1.5000)), ('0.A.log_half_width', tensor(-0.6931)), ('0.A.log_scale', tensor(2.3026)), ('4.weight', tensor([1.])), ('4.bias', tensor([0.])), ('4.running_mean', tensor([0.])), ('4.running_var', tensor([1.])), ('4.num_batches_tracked', tensor(0))])
with torch.no_grad():
    yhat_init_2 = model_2(X_all)
# print(yhat_init_2.size())
# print(yhat_init_2)
with torch.no_grad():
    plt.scatter(yhat_init_2.numpy(), y_all.numpy())
    plt.title("Ground truth versus predicted log odds")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
with torch.no_grad():
    print(criterion(yhat_init_2, y_all))
tensor(0.6606)
model_2.to(device)
Sequential( (0): SoftThreshold2( (L): ShiftThenScale(10.0) (A): MyAbs() (S): Sigmoid() (R): ReLU() ) (1): RemoveSimultaneous( (R): ReLU() ) (2): Integrate(sum_dim=(1, 2, 3) (3): Flatten(start_dim=1, end_dim=-1) (4): BatchNorm1d(1, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True) (5): Flatten(start_dim=0, end_dim=-1) )
optimizer_2 = optim.Adam(model_2.parameters(), lr=1e-1, weight_decay=0.0)
outputs, labels = myopt(device, optimizer_2, model_2, criterion, dyn_loader, iter=80)
100% |#########################################| Time: 0:13:06, loss: 0.370897
# with torch.no_grad():
#     print(outputs, labels)
with torch.no_grad():
    plt.scatter(outputs.numpy(), labels.numpy())
    plt.title("Ground truth versus predicted log odds")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
[(name, theta, theta.grad) for name, theta in model_2.named_parameters()]
[('0.L.shift', Parameter containing: tensor(-0.2964, device='cuda:0', requires_grad=True), tensor(0.0182, device='cuda:0')), ('0.L.scale', Parameter containing: tensor(-12.0017, device='cuda:0', requires_grad=True), tensor(2.5145e-05, device='cuda:0')), ('0.A.shift', Parameter containing: tensor(-3.0548, device='cuda:0', requires_grad=True), tensor(4.5884e-07, device='cuda:0')), ('0.A.log_half_width', Parameter containing: tensor(-2.0869, device='cuda:0', requires_grad=True), tensor(5.6932e-08, device='cuda:0')), ('0.A.log_scale', Parameter containing: tensor(3.7463, device='cuda:0', requires_grad=True), tensor(-8.5860e-08, device='cuda:0')), ('4.weight', Parameter containing: tensor([8.5806], device='cuda:0', requires_grad=True), tensor([-0.0274], device='cuda:0')), ('4.bias', Parameter containing: tensor([-0.0011], device='cuda:0', requires_grad=True), tensor([0.0008], device='cuda:0'))]
# We end up with a pretty narrow band for the standard deviation.
with torch.no_grad():
    cent = model_2._modules['0'].A.shift.cpu().numpy()
    hw = torch.exp(model_2._modules['0'].A.log_half_width).cpu().numpy()
    print("Log var center: {:.2f}, log var half width: {:.2f}, sd low, high: {:.2f}, {:.2f}".format(
        cent, hw, np.exp(0.5*(cent - hw)), np.exp(0.5*(cent + hw))
    ))
Log var center: -3.05, log var half width: 0.12, sd low, high: 0.20, 0.23
Enforcing sparsity in NPC¶
Another option for improving model performance is to build into the loss function a penalty for excessive detections in the NPC. We don't just want to separate the two groups, we want to eliminate detections from NPC.
One challenge in doing this is figuring out the correct trade-off between the original criterion and the "sparsity" we are trying to enforce. Presumably, one could use a reverse-lasso type approach: start with no penalty and then gradually increase it.
Here the hypothetical output from the model includes the log-odds values as well as the values after the integration step. While we have implemented this approach, we forgo it here, since the results above suffice; we mention it only for the sake of completeness.
# Penalize NPC - finding the right balance is key here
def criterion_with_sparsity(myout, labels, pen=1):
    return criterion(myout[0,:], labels) + 0.69 * pen * torch.mean((1 - labels) * myout[1,:])
# def criterion3(myout, labels):
#     return torch.mean((1 - labels) * myout[1,:]) + torch.mean(labels / myout[1,:])
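To see how the two terms of criterion_with_sparsity trade off, here is a minimal plain-Python sketch of the same expression for a small batch. It assumes, as described above, that the model returns both the log-odds and the integrated values; the helper names and all numbers are hypothetical. The penalty term averages the integrated activation over the batch, counting only NPC devices (label 0), and is weighted by 0.69 * pen.

```python
import math

def bce_with_logits(logits, labels):
    """Mean binary cross-entropy on logits (numerically stable form)."""
    terms = [max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
             for z, y in zip(logits, labels)]
    return sum(terms) / len(terms)

def sparsity_loss(logits, integrated, labels, pen=1.0):
    """BCE plus a penalty on integrated activation in label-0 (NPC) cases."""
    penalty = sum((1.0 - y) * v for y, v in zip(labels, integrated)) / len(labels)
    return bce_with_logits(logits, labels) + 0.69 * pen * penalty

labels = [1.0, 1.0, 0.0, 0.0]          # plant, plant, NPC, NPC
logits = [2.0, 1.0, -1.0, -2.0]        # illustrative log-odds
integrated = [0.30, 0.20, 0.10, 0.02]  # illustrative integrated activations

# Sweep the penalty weight upward from zero
for pen in (0.0, 1.0, 5.0):
    print(pen, round(sparsity_loss(logits, integrated, labels, pen), 4))
```

Sweeping pen upward from 0, as in the loop, mirrors the idea of starting with no penalty and then increasing it.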
Conclusion¶
Herein we have shown that we can formalize the ad hoc method from analysis-01-dynamics as a neural network model and learn the optimal set of parameters for this model given data.
The data here consist of only 46 cases, which is sufficient for learning the parameters of our model. However, our approach suggests that with a larger dataset we could consider larger models, e.g. a deep neural network, to capture more complex patterns distinguishing plants from controls.