Introduction¶

Here we consider univariate models for separating corn plants from NPC. From our previous analysis we know that monitoring the velocity in the second principal component is informative for distinguishing between the groups. We will explore various modeling options to build some intuition for developing more complex models.

Setup¶

Import packages and determine the device we will use for computations.

In [1]:
import os
import copy
import h5py
import pandas as pd
import numpy as np
import progressbar
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from matplotlib import pyplot as plt
import plotnine as gg
In [2]:
%matplotlib inline
In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
Out[3]:
device(type='cpu')

Data to use¶

Get RootTracker deployment data.

In [4]:
DATA_LOC = "local"
In [5]:
DATA_DIR = "/usr/local/share/rtphenos-workspace"
if DATA_LOC == "google":
    from google.colab import drive
    drive.mount("/content/drive")
    DATA_DIR = "/content/drive/My Drive/rtphenos-workspace"
In [6]:
rt_deps = pd.read_csv(os.path.join(DATA_DIR, "deployments", "rt_acc1_deps.csv"))
rt_deps.head()
Out[6]:
deployment_id roottracker_serial ram_serial x y bench crop genotype order pot_on_bench pot_on_row rep row deployment_name NPC
0 54 EG00043 D76080D25B7EA126 1.0 11.0 2 Soy PI416937 11 2 11 1 1 ACC Test 1 False
1 54 EG00044 D76080D25B3BD828 1.0 3.0 1 Wheat Comet 3 3 3 1 1 ACC Test 1 False
2 54 EG00045 D76080D25B7F7526 4.0 7.0 1 Cotton PI513384 115 7 7 1 4 ACC Test 1 False
3 54 EG00046 D76080D25B3C9628 2.0 11.0 2 Soy PI416937 47 2 11 1 2 ACC Test 1 False
4 54 EG00047 D76080D25B3A6C28 3.0 8.0 1 Cotton PI513384 80 8 8 1 3 ACC Test 1 False

We want to limit our attention to just the corn and NPC devices, since we believe those will show the clearest separation. We also want to remove devices with odd data.

The first definition of bad_roottrackers corresponds to devices with odd data coming in on paddles 2 and 3. The second adds to those the top 8 devices in terms of integrated mean-velocity-below-a-threshold, a set that is balanced between corn and NPC. The filtering we did in our previous analysis was critical for eliminating these NPCs. Here we can do something similar using the RemoveSimultaneous module defined below.

In [7]:
bad_roottrackers = ['EG00237', 'EG00226']
# bad_roottrackers = ['EG00237', 'EG00226', 'EG00228', 'EG00062', 'EG00157', 'EG00096', 'EG00160', 'EG00103', 'EG00159', 'EG00140']
# bad_roottrackers = ['EG00133', 'EG00105', 'EG00226'] # Soy
In [8]:
rt_deps_to_use = rt_deps.query(f"(crop == 'Corn' | crop == 'NPC') & ~(roottracker_serial in {bad_roottrackers})").reset_index(drop=True).reset_index(drop=False)
rt_deps_to_use['rt_serial'] = rt_deps_to_use['roottracker_serial']
rt_deps_to_use.shape
Out[8]:
(46, 17)

Load data¶

We extend the Dataset class from PyTorch to access our data.

In [9]:
class Dynamics(Dataset):

    def __init__(self, rt_deps, base_dir, start=2000, end=4000, slice=3, log_transform=False):
        self.iter_idx = 0
        if isinstance(rt_deps, str):
            self.rt_deps = pd.read_csv(rt_deps)  # rt_deps is a path to a CSV file
        elif isinstance(rt_deps, pd.DataFrame):
            self.rt_deps = rt_deps.reset_index()
        else:
            raise ValueError("rt_deps must be a path to a csv file or a DataFrame")
        self.base_dir = base_dir
        self.key = 'dyn'
        self.cache = [None for i in range(len(self.rt_deps))]
        self.start = start
        self.end = end
        self.slice = slice
        self.log_transform = log_transform

    def __len__(self):
        return self.rt_deps.shape[0]

    def get_path(self, idx):
        rt_serial = self.rt_deps['roottracker_serial'][idx]
        full_path = os.path.join(self.base_dir, self.rt_deps['deployment_name'][idx], f"{rt_serial}.h5")
        if not os.path.exists(full_path):
            raise Exception(f"Path {full_path} does not exist")
        return full_path

    def __getitem__(self, idx):
        out = None
        if self.cache[idx] is None:
            with h5py.File(self.get_path(idx), 'r') as h5r:
                #x = torch.log(x + 1e-10)
                temp = np.array(
                        h5r[self.key][self.start:self.end,:,:,:]
                    )[:,:,:,self.slice]
                if self.log_transform:
                    temp = np.log(temp + 1e-10)
                self.cache[idx] = torch.tensor(temp).float()
        return self.cache[idx], torch.tensor(~self.rt_deps['NPC'][idx], dtype=torch.float32)

    def lookup(self, rt_serial):
        out = self.rt_deps.query(f"roottracker_serial == '{rt_serial}'")
        nout = len(out)
        if nout == 1:
            return out.index[0]
        elif nout > 1:
            raise Exception("Found multiple")
        else:
            raise Exception("Found none")

    def __iter__(self):
        self.iter_idx = 0
        return self

    def __next__(self):
        if self.iter_idx < len(self):
            out = self[self.iter_idx]
            self.iter_idx += 1
            return out
        raise StopIteration
In [10]:
# Should be balanced
rt_deps_to_use['crop'].value_counts()
Out[10]:
crop
Corn    23
NPC     23
Name: count, dtype: int64

Data Loaders and data¶

VAR_TO_USE=3 is mean velocity, while VAR_TO_USE=5 is standard deviation of velocity, both for the second component. Either of these covariates can be used to separate the corn plants from the NPC. However, VAR_TO_USE=5 has the disadvantage that it flags lots of points as anomalous, so it seems to work well at an aggregate level, but less well if what we are after is individual root touches.

Note: we use start=2000 and end=4000 below, which is not the whole time series. Results may change here if you make use of all the data (e.g. start=0 and end=8043).

In [11]:
VAR_TO_USE, LOG_TRANSFORM = 3, False
# VAR_TO_USE, LOG_TRANSFORM = 5, True
BATCH_SIZE = rt_deps_to_use.shape[0]
In [12]:
dyn_data = Dynamics(rt_deps_to_use, os.path.join(DATA_DIR, "dynamics"), start=2000, end=4000, slice=VAR_TO_USE, log_transform=LOG_TRANSFORM)
In [13]:
dyn_loader = DataLoader(dyn_data, batch_size=BATCH_SIZE, shuffle=False)
In [14]:
y_all = torch.stack([d[1] for d in dyn_data])
In [15]:
X_all = torch.stack([d[0] for d in dyn_data])
In [16]:
X_all.shape
Out[16]:
torch.Size([46, 2000, 12, 22])

Models¶

Here we build the pieces that make up the models and then define the models themselves. The basic idea, based on our previous data exploration, is as follows (a toy sketch follows the list):

  1. Soft Threshold
  2. Integrate
  3. Shift and scale to get things right on the log-odds scale
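
As a minimal sketch of these three steps (made-up data and illustrative constants; the real modules below operate on batches shaped like X_all):

In [ ]:
# Toy sketch only: random data standing in for a batch shaped like X_all.
x = torch.randn(4, 2000, 12, 22)                     # 4 toy devices
soft = torch.sigmoid(-100.0 * (x - (-0.25)))         # 1. soft threshold: ~1 where x < -0.25
integrated = soft.mean(dim=(1, 2, 3))                # 2. integrate: per-device fraction of "active" points
log_odds = 5.0 * (integrated - integrated.median())  # 3. shift and scale onto the log-odds scale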

Modules¶

In [17]:
class ShiftThenScale(nn.Module):

    def __init__(self, scale_mag=1.):
        super().__init__()
        self.shift = nn.Parameter(torch.randn(()))
        self.scale = nn.Parameter(torch.randn(()))
        self.scale_mag = scale_mag

    def forward(self, x):
        return self.scale_mag * self.scale * (x - self.shift)

    def __repr__(self):
        return f"ShiftThenScale({self.scale_mag})"
In [18]:
class SoftThreshold(nn.Module):

    def __init__(self, scale_mag=1.0):
        super().__init__()
        self.L = ShiftThenScale(scale_mag=scale_mag)
        self.S = nn.Sigmoid()
        # self.R = nn.ReLU()

    def forward(self, x):
        x = self.L(x)
        x = self.S(x)
        # x = 2.0 * (self.R(x - 0.5))
        return x
In [19]:
class Integrate(nn.Module):
    def __init__(self, sum_dim):
        super().__init__()
        self.sum_dim = sum_dim

    def forward(self, x):
        n = len(x.size())
        x = torch.mean(x, dim=self.sum_dim, keepdim=True)
        return x

    def __repr__(self):
        return (f"Integrate(sum_dim={self.sum_dim})")
In [20]:
class Unsqueeze(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.unsqueeze(4)
In [21]:
class Standardize(nn.Module):

    def __init__(self, return_both=True, use_sqrt=False):
        super().__init__()
        self.return_both = return_both
        self.use_sqrt = use_sqrt

    def forward(self, x):
        xs = x.squeeze()
        if self.use_sqrt:
            xs = torch.sqrt(xs + 1e-10)
        # Median seems like a more intuitive approach, but mean seems better in terms of learning
        # m = torch.mean(xs) # but this makes things more sensitive to all data
        # s = torch.std(xs) # you don't gain anything by squashing things at 50% here
        m = torch.median(xs) # ensures balance, and could be used with known imbalance
        s = torch.sqrt(torch.mean(torch.square(xs - m)))
        y = (xs - m) / s
        if self.return_both:
            return torch.stack([y, xs])
        else:
            return y
In [22]:
class LinearAndCopy(nn.Module):

    def __init__(self, return_both=False):
        super().__init__()
        self.return_both = return_both
        self.L1 = nn.Linear(1,1)

    def forward(self, x):
        xs = x.squeeze()
        y = self.L1(x)
        ys = y.squeeze()
        if self.return_both:
            return torch.stack([ys, xs])
        else:
            return ys
In [23]:
class RemoveSimultaneous(nn.Module):
    def __init__(self):
        super().__init__()
        self.R = nn.ReLU()

    def forward(self, x):
        shp = x.size()
        npads = shp[2]
        y = x.new(shp)
        for i in range(npads):
            others = [True] * npads
            others[i] = False
            val, idc = torch.max(x[:,:,others,:], dim=2)
            y[:,:,i,:] = x[:,:,i,:] - val
        y = self.R(y)
        return y
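
RemoveSimultaneous is not exercised elsewhere in this notebook, so here is a quick toy check (made-up values) of what it does: for each paddle it keeps only the amount by which that paddle's signal exceeds the maximum over the other paddles at the same time point, zeroing everything else.

In [ ]:
rs = RemoveSimultaneous()
toy = torch.zeros(1, 1, 3, 1)                    # toy batch: 1 device, 1 time step, 3 paddles, last dim of size 1
toy[0, 0, :, 0] = torch.tensor([0.9, 0.2, 0.1])  # paddle 0 dominates
print(rs(toy)[0, 0, :, 0])                       # tensor([0.7000, 0.0000, 0.0000])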

Model definitions¶

Below we define the models. As mentioned previously, the basic approach is:

  1. Soft Threshold
  2. Integrate
  3. Shift and scale to get things right on the log-odds scale

Based on some numerical experience, there are a few things to consider with (1) and (3).

Regarding (1): above, we define ShiftThenScale as an approach equivalent to nn.Linear, but under a different parameterization. ShiftThenScale is advantageous in that we can easily translate the work from analysis-01-dynamics to its parameterization, making it easy to understand its initialization and final output. Of course, nn.Linear should be equivalent, and it does lead to the same behavior in terms of the local minima found.
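
To make the equivalence concrete, here is a quick check (with made-up parameter values) that a ShiftThenScale with parameters (shift, scale) matches an nn.Linear(1, 1) with weight = scale_mag * scale and bias = -scale_mag * scale * shift:

In [ ]:
sts = ShiftThenScale(scale_mag=1.0)
with torch.no_grad():
    sts.shift.fill_(-0.25)    # illustrative values (the low-momentum defaults used later)
    sts.scale.fill_(-100.0)

lin = nn.Linear(1, 1)
with torch.no_grad():
    lin.weight.copy_((sts.scale_mag * sts.scale).reshape(1, 1))          # w = scale_mag * scale
    lin.bias.copy_((-sts.scale_mag * sts.scale * sts.shift).reshape(1))  # b = -scale_mag * scale * shift

with torch.no_grad():
    x = torch.linspace(-1.0, 1.0, 5).reshape(-1, 1)
    print(torch.allclose(sts(x), lin(x)))  # True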

Regarding (3): having a final layer determine the correct shift and scale to return to the log-odds scale is problematic. In particular, there seems to be a stable point (or almost stable point, which I assume is not a local minimum) when the log-odds cluster at very similar values. Since we are trying to separate balanced data, it seems reasonable to center the data at the median, which is like ensuring that the number of false positives equals the number of false negatives. Having centered the data, it seems reasonable to standardize it as well, which eliminates the issue of the log-odds values taking on very similar values. In experimentation, this works well. Standardizing the data is very similar to batch normalization. Wikipedia suggests that there is some controversy over why batch normalization is a good idea, but here at least, it seems to be effective because it eliminates a stationary point, or at least makes it easier to pass through one.
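
As a small illustration of the median-centering idea (made-up integrated scores, not data from this notebook), the Standardize module defined above maps a balanced batch so that half of the devices land at positive log-odds and the spread is fixed:

In [ ]:
std = Standardize(return_both=False)
toy_scores = torch.tensor([0.10, 0.12, 0.11, 0.13, 0.55, 0.60, 0.58, 0.62])  # made-up values
z = std(toy_scores)
print(z)
print((z > 0).float().mean())  # 0.5: half of the batch gets positive log-odds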

Below, we will work through all of these issues via case study.

From our experimentation when using VAR_TO_USE=3, it seems that there are two local modes: (1) low momentum or (2) high momentum in the second principal component, which is not too surprising. Below, you can set whether to initialize so as to find the "low" or "high" mode. If VAR_TO_USE=5, then there is a mode when the log standard deviation is above a certain level.

In [24]:
MOM2_LOW = True and (VAR_TO_USE == 3)
MOM2_LOW
Out[24]:
True
In [25]:
# For balanced data!
def make_model_1(return_both=False):
    out = nn.Sequential(
        SoftThreshold(),
        Integrate(sum_dim = (1, 2, 3)),
        Standardize(use_sqrt=False, return_both=return_both),
    )
    return out
In [26]:
def make_model_2(return_both=False):
    out = nn.Sequential(
        Unsqueeze(),
        nn.Linear(1,1),
        nn.Sigmoid(),
        Integrate(sum_dim = (1, 2, 3)),
        Standardize(return_both=return_both)
    )
    return out
In [27]:
def make_model_3(return_both=False):
    out = nn.Sequential(
        Unsqueeze(),
        nn.Linear(1,1),
        nn.Sigmoid(),
        Integrate(sum_dim = (1, 2, 3)),
        nn.Linear(1,1),
        nn.Flatten(start_dim=0)
    )
    return out
In [28]:
def make_model_4(return_both=False):
    out = nn.Sequential(
        Unsqueeze(),
        nn.Linear(1,1),
        nn.Sigmoid(),
        Integrate(sum_dim = (1, 2, 3)),
        nn.Flatten(start_dim=1),
        nn.BatchNorm1d(num_features=1, momentum=0.0),
        nn.Flatten(start_dim=0)
    )
    return out

Helper functions¶

In [29]:
def myopt(device, optimizer, model, criterion, data_loader, iter=50):
    widgets = [
        progressbar.Percentage(), " ",
        progressbar.Bar(), " ",
        progressbar.ETA(), ", ",
        progressbar.Variable('loss', precision=6),
    ]
    with progressbar.ProgressBar(max_value=iter, widgets=widgets) as bar:
        for epoch in range(iter):  # loop over the dataset multiple times
            running_loss = 0.0
            for i, data in enumerate(data_loader, 0): # There should just be 1 batch
                inputs_cpu, labels_cpu = data
                labels_cpu = 1*labels_cpu
                inputs = inputs_cpu.to(device)
                labels = labels_cpu.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                outputs_cpu = outputs.cpu()
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
            bar.update(epoch, loss=loss.item())
    return outputs_cpu, labels_cpu

Example - Model 1¶

Model 1 uses ShiftThenScale when constructing the soft threshold. Here we use that convenient parameterization and our knowledge of reasonable parameters from analysis-01-dynamics to demonstrate that we can nearly separate the two groups.

In [30]:
example_model = make_model_1(return_both=False)
In [31]:
print(f"Model structure: {example_model}\n\n")
Model structure: Sequential(
  (0): SoftThreshold(
    (L): ShiftThenScale(1.0)
    (S): Sigmoid()
  )
  (1): Integrate(sum_dim=(1, 2, 3))
  (2): Standardize()
)


Here are the randomly initialized values.

In [32]:
example_model.state_dict()
Out[32]:
OrderedDict([('0.L.shift', tensor(-0.6609)), ('0.L.scale', tensor(-1.2181))])

Based on our previous work, we know this model should work well. Here are some parameter values that show this is the case:

In [33]:
if MOM2_LOW:
    # Low mom 2
    model_1_defaults = {
        # Momentum less than -0.25
        '0.L.scale': torch.tensor(-100.),
        '0.L.shift': torch.tensor(-0.25),
    }
else:
    # High mom 2 OR High standard deviation of mom 2 (this works in either case)
    model_1_defaults = {
        '0.L.scale': torch.tensor(5.),
        '0.L.shift': torch.tensor(0.82),
    }
# elif VAR_TO_USE == 5:
#     # less variation in mom_2 indicative of NPC / more variation in mom_2 indicative of corn
#     # Perhaps not surprising since we saw that this was true for global mom_2 variance in the previous analysis.
#     model_1_defaults = {
#         # Log variance less than 1
#         '0.L.scale': torch.tensor(-10.0),
#         '0.L.shift': torch.tensor(1.0),
#     }
In [34]:
example_model.load_state_dict(model_1_defaults)
Out[34]:
<All keys matched successfully>
In [35]:
example_model.state_dict()
Out[35]:
OrderedDict([('0.L.shift', tensor(-0.2500)), ('0.L.scale', tensor(-100.))])

Now, let's compute the log odds values using the model with parameters chosen as above, plot them, and then compute the loss.

In [36]:
with torch.no_grad():
    example_y_hat = example_model(X_all)
In [37]:
with torch.no_grad():
    plt.scatter(example_y_hat.numpy(), y_all.numpy())
    plt.title("Ground truth vs. log odds (M1)")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
[Figure: Ground truth vs. log odds (M1)]
In [38]:
example_criterion = nn.BCEWithLogitsLoss()
In [39]:
with torch.no_grad():
    outputs = example_model(X_all)
    loss = example_criterion(outputs, y_all)
    print(loss.item())
0.3672024607658386

Inference¶

Now, let's show that we can learn the parameters that separate the two groups. As mentioned previously, when considering VAR_TO_USE=3, there is one catch: there are two modes, one corresponding to low momentum and the other corresponding to high momentum (both in the second principal component). The initialization matters in terms of which mode we converge to. To make that clear, we seed the random number generator so that each mode is found according to MOM2_LOW.

In [40]:
criterion = nn.BCEWithLogitsLoss()

Model 1¶

Model 1 uses ShiftThenScale and standardizes (centering at the median) in the last layer.

In [41]:
if MOM2_LOW:
    torch.manual_seed(9834) # Low mom 2
else:
    torch.manual_seed(1234) # High mom 2
In [42]:
model_1 = make_model_1()
In [43]:
[(name, theta, theta.grad) for name, theta in model_1.named_parameters()]
Out[43]:
[('0.L.shift',
  Parameter containing:
  tensor(-0.1109, requires_grad=True),
  None),
 ('0.L.scale',
  Parameter containing:
  tensor(-0.1387, requires_grad=True),
  None)]

Here, we plot the log-odds values after the initialization.

In [44]:
with torch.no_grad():
    y_hat_init_1 = model_1(X_all)
In [45]:
with torch.no_grad():
    plt.scatter(y_hat_init_1.numpy(), y_all.numpy())
    plt.title("Ground truth vs. log odds (M1 at initialization)")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
[Figure: Ground truth vs. log odds (M1 at initialization)]
In [46]:
model_1.to(device)
Out[46]:
Sequential(
  (0): SoftThreshold(
    (L): ShiftThenScale(1.0)
    (S): Sigmoid()
  )
  (1): Integrate(sum_dim=(1, 2, 3))
  (2): Standardize()
)

Let's learn the parameters of interest:

In [47]:
optimizer_1 = optim.Adam(model_1.parameters(), lr=1e-1, weight_decay=0.0)
In [48]:
ITER = 200

widgets = [
    progressbar.Percentage(), " ",
    progressbar.Bar(), " ",
    progressbar.ETA(), ", ",
    progressbar.Variable('loss', precision=6),
]

with progressbar.ProgressBar(max_value=ITER, widgets=widgets) as bar:
    for epoch in range(ITER):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(dyn_loader, 0): # There should just be 1 batch
            inputs_cpu, labels_cpu = data
            labels_cpu = 1*labels_cpu
            inputs = inputs_cpu.to(device)
            labels = labels_cpu.to(device)
            optimizer_1.zero_grad()
            outputs = model_1(inputs)
            outputs_cpu = outputs.cpu()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer_1.step()
        bar.update(epoch, loss=loss.item())
100% |#########################################| Time:  0:00:26, loss: 0.368255

Finally, let's plot the log-odds values against ground truth and then print out the parameters.

In [49]:
with torch.no_grad():
    plt.scatter(outputs_cpu.numpy(), labels_cpu.numpy())
    plt.title("Ground truth vs. log odds (M1 after opt)")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
[Figure: Ground truth vs. log odds (M1 after opt)]

The parameters:

In [50]:
[(name, theta, theta.grad) for name, theta in model_1.named_parameters()]
Out[50]:
[('0.L.shift',
  Parameter containing:
  tensor(-0.5213, requires_grad=True),
  tensor(-5.5790e-05)),
 ('0.L.scale',
  Parameter containing:
  tensor(-7.6106, requires_grad=True),
  tensor(0.0013))]

Model 2¶

Model 2 uses nn.Linear when defining the soft threshold. As before, it can converge to either mode. Below, we do the same as above: show the log-odds values at initialization and then after optimizing.

In [51]:
if MOM2_LOW:
    torch.manual_seed(1234) # Low Mom2
else:
    torch.manual_seed(3456) # High Mom2
In [52]:
model_2 = make_model_2()
In [53]:
[(name, theta, theta.grad) for name, theta in model_2.named_parameters()]
Out[53]:
[('1.weight',
  Parameter containing:
  tensor([[-0.9420]], requires_grad=True),
  None),
 ('1.bias',
  Parameter containing:
  tensor([-0.1962], requires_grad=True),
  None)]
In [54]:
with torch.no_grad():
    yhat_init_2 = model_2(X_all)
    plt.scatter(yhat_init_2.numpy(), y_all.numpy())
    plt.title("Ground truth vs. log odds (M2 at initialization)")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
[Figure: Ground truth vs. log odds (M2 at initialization)]
In [55]:
model_2.to(device)
Out[55]:
Sequential(
  (0): Unsqueeze()
  (1): Linear(in_features=1, out_features=1, bias=True)
  (2): Sigmoid()
  (3): Integrate(sum_dim=(1, 2, 3))
  (4): Standardize()
)
In [56]:
# You can use lr=1e-1 to start, but to get lower loss, you then need to switch to a lower learning rate
optimizer_2 = optim.Adam(model_2.parameters(), lr=1e-1, weight_decay=0.0)
In [57]:
outputs, labels = myopt(device, optimizer_2, model_2, criterion, dyn_loader, iter=100)
100% |#########################################| Time:  0:00:36, loss: 0.370083
In [58]:
with torch.no_grad():
    plt.scatter(outputs.numpy(), labels.numpy())
    plt.title("Ground truth vs. log odds (M2 after opt)")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
[Figure: Ground truth vs. log odds (M2 after opt)]
In [59]:
[(name, theta, theta.grad) for name, theta in model_2.named_parameters()]
Out[59]:
[('1.weight',
  Parameter containing:
  tensor([[-6.5278]], requires_grad=True),
  tensor([[0.0020]])),
 ('1.bias',
  Parameter containing:
  tensor([-3.6591], requires_grad=True),
  tensor([0.0006]))]

Model 3¶

Model 3 uses an nn.Linear layer as the final layer, trying to learn the map onto the log-odds scale directly. With this setup we end up in a flat part of the parameter space where the log-odds values are all clustered at similar values.
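
One way to see why such a flat region exists (a toy calculation, assuming all outputs collapse to a single shared log-odds value): for a balanced batch, the gradient of BCEWithLogitsLoss with respect to that shared value is sigmoid(c) - 1/2, which vanishes at c = 0.

In [ ]:
c = torch.zeros((), requires_grad=True)   # one shared log-odds value for every device
logits = c.expand(46)                     # all 46 devices get the same score
balanced_labels = torch.cat([torch.ones(23), torch.zeros(23)])
loss = nn.BCEWithLogitsLoss()(logits, balanced_labels)
loss.backward()
print(loss.item())  # log(2) ≈ 0.6931, the plateau loss seen in the optimization below
print(c.grad)       # tensor(0.): no gradient pushes the shared value away from zero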

In [60]:
torch.manual_seed(1234)
Out[60]:
<torch._C.Generator at 0x7fcafd998650>
In [61]:
model_3 = make_model_3()
In [62]:
[(name, theta, theta.grad) for name, theta in model_3.named_parameters()]
Out[62]:
[('1.weight',
  Parameter containing:
  tensor([[-0.9420]], requires_grad=True),
  None),
 ('1.bias',
  Parameter containing:
  tensor([-0.1962], requires_grad=True),
  None),
 ('4.weight',
  Parameter containing:
  tensor([[-0.4803]], requires_grad=True),
  None),
 ('4.bias',
  Parameter containing:
  tensor([-0.2667], requires_grad=True),
  None)]
In [63]:
with torch.no_grad():
    yhat_init_3 = model_3(X_all)
    plt.scatter(yhat_init_3.numpy(), y_all.numpy())
    plt.title("Ground truth vs. log odds (M3 at initialization)")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
[Figure: Ground truth vs. log odds (M3 at initialization)]
In [64]:
model_3.to(device)
Out[64]:
Sequential(
  (0): Unsqueeze()
  (1): Linear(in_features=1, out_features=1, bias=True)
  (2): Sigmoid()
  (3): Integrate(sum_dim=(1, 2, 3))
  (4): Linear(in_features=1, out_features=1, bias=True)
  (5): Flatten(start_dim=0, end_dim=-1)
)
In [65]:
# If you start at a good initialization, this can work.  But you seem to get stuck in a flat place more often than not.
optimizer_3 = optim.Adam(model_3.parameters(), lr=1e-2, weight_decay=0.0)
In [66]:
outputs, labels = myopt(device, optimizer_3, model_3, criterion, dyn_loader, iter=200)
100% |#########################################| Time:  0:01:12, loss: 0.693142
In [67]:
with torch.no_grad():
    plt.scatter(outputs.numpy(), labels.numpy())
    plt.title("Ground truth vs. log odds (M3 after opt)")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
[Figure: Ground truth vs. log odds (M3 after opt)]
In [68]:
# We have reached a very, very flat place.
[(name, theta, theta.grad) for name, theta in model_3.named_parameters()]
Out[68]:
[('1.weight',
  Parameter containing:
  tensor([[0.3495]], requires_grad=True),
  tensor([[-1.3065e-05]])),
 ('1.bias',
  Parameter containing:
  tensor([-0.4900], requires_grad=True),
  tensor([-1.6631e-06])),
 ('4.weight',
  Parameter containing:
  tensor([[-0.1585]], requires_grad=True),
  tensor([[2.5449e-05]])),
 ('4.bias',
  Parameter containing:
  tensor([0.0601], requires_grad=True),
  tensor([-1.7222e-05]))]

Model 4¶

Above, we saw that standardization worked well for getting us to the log-odds scale, but that a plain linear layer did not. Model 4 employs batch normalization, which effectively combines the two: first you standardize the batch (or in our case the whole, balanced data set), then you apply a linear transformation. This works well. In fact, more often than not it seems to help us get to the global minimum.
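
As a quick toy check of that description (made-up values, not part of the analysis): in training mode, nn.BatchNorm1d normalizes by the batch's own mean and (biased) standard deviation and then applies a learnable affine map, i.e. the Standardize idea (with the mean in place of the median) followed by a linear layer.

In [ ]:
bn = nn.BatchNorm1d(num_features=1, momentum=0.0)
toy = torch.tensor([[0.1], [0.2], [0.6], [0.7]])   # toy integrated scores, shape (batch, 1)
with torch.no_grad():
    bn.weight.fill_(5.0)   # the "scale" part of the affine map
    bn.bias.fill_(0.2)     # the "shift" part of the affine map
    out = bn(toy)          # training mode: uses the batch's own statistics
manual = (toy - toy.mean()) / torch.sqrt(toy.var(unbiased=False) + bn.eps)
print(torch.allclose(out, 5.0 * manual + 0.2, atol=1e-5))  # True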

In [69]:
model_4 = make_model_4()
In [70]:
model_4_init = model_4.state_dict()
model_4_init
Out[70]:
OrderedDict([('1.weight', tensor([[0.3731]])),
             ('1.bias', tensor([0.6514])),
             ('5.weight', tensor([1.])),
             ('5.bias', tensor([0.])),
             ('5.running_mean', tensor([0.])),
             ('5.running_var', tensor([1.])),
             ('5.num_batches_tracked', tensor(0))])
In [71]:
with torch.no_grad():
    yhat_init_4 = model_4(X_all)
In [72]:
with torch.no_grad():
    plt.scatter(yhat_init_4.numpy(), y_all.numpy())
    plt.title("Ground truth vs. log odds (M4 at initialization)")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
[Figure: Ground truth vs. log odds (M4 at initialization)]
In [73]:
model_4.to(device)
Out[73]:
Sequential(
  (0): Unsqueeze()
  (1): Linear(in_features=1, out_features=1, bias=True)
  (2): Sigmoid()
  (3): Integrate(sum_dim=(1, 2, 3))
  (4): Flatten(start_dim=1, end_dim=-1)
  (5): BatchNorm1d(1, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True)
  (6): Flatten(start_dim=0, end_dim=-1)
)
In [74]:
optimizer_4 = optim.Adam(model_4.parameters(), lr=1e-1, weight_decay=0.0)
In [75]:
outputs, labels = myopt(device, optimizer_4, model_4, criterion, dyn_loader, iter=50)
100% |#########################################| Time:  0:00:18, loss: 0.194433
In [76]:
# with torch.no_grad():
#     print(outputs, labels)
In [77]:
with torch.no_grad():
    plt.scatter(outputs.numpy(), labels.numpy())
    plt.title("Ground truth vs. log odds (M4 after opt)")
    plt.xlabel("log odds")
    plt.ylabel("Plant / No plant (1 or 0)")
[Figure: Ground truth vs. log odds (M4 after opt)]
In [78]:
[(name, theta, theta.grad) for name, theta in model_4.named_parameters()]
Out[78]:
[('1.weight',
  Parameter containing:
  tensor([[-4.6989]], requires_grad=True),
  tensor([[0.0288]])),
 ('1.bias',
  Parameter containing:
  tensor([-2.0223], requires_grad=True),
  tensor([0.0157])),
 ('5.weight',
  Parameter containing:
  tensor([5.0479], requires_grad=True),
  tensor([-0.0347])),
 ('5.bias',
  Parameter containing:
  tensor([0.2015], requires_grad=True),
  tensor([-0.0011]))]

Conclusion¶

Herein, we have learned several key things. First, it is possible to approximate the results from analysis-01-dynamics using a soft threshold function. Second, one can learn the optimal parameters, but one must be careful about getting the output on the correct scale. Third, regarding that, standardization or batch normalization is critical for returning to the log-odds scale, since a linear layer alone runs into problems. Fourth, there may be multiple local minima, and the initialization determines which mode is found.

In [ ]: