xforms with TensorFlow

This notebook demonstrates how to use xforms to generate models that are easy to reload, with TensorFlow as the fitter. The workflow closely follows the standard one. The only difference is the fitting step, which uses fit_tf (preceded by fit_tf_init) instead of fit_basic.

[1]:
import logging
from pathlib import Path

import nems0.db as db
import nems0.modelspec as ms
import nems0.recording as recording
import nems0.uri
import nems0.xforms as xforms
[nems.configs.defaults INFO] Saving log messages to /tmp/nems\NEMS 2020-05-28 133629.log

Configuration

[2]:
# get the data and results paths
# the imports above bind the name nems0, not nems
results_dir = nems0.get_setting('NEMS_RESULTS_DIR')
signals_dir = nems0.get_setting('NEMS_RECORDINGS_DIR')
[3]:
# download some demo data
recording.get_demo_recordings(signals_dir)
datafile = Path(signals_dir) / 'TAR010c-18-1.pkl'

Data Loading and Preprocessing

[4]:
load_command = 'nems.demo.loaders.demo_loader'
expt_id = 'TAR010c'
batch = 271
cell_id = 'TAR010c-18-1'
[5]:
modelspec_name = 'dlog-wc.18x1.g-fir.1x15-lvl.1-dexp.1'
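The model name is a string of keywords joined by hyphens, and the initializer log further down reports the same five keywords one by one (kw: dlog, kw: wc.18x1.g, and so on). As a minimal sketch using only plain string operations (this is not how NEMS parses keywords internally, and the name/options split shown here is an assumption made for illustration), the string can be taken apart like this:

```python
modelspec_name = 'dlog-wc.18x1.g-fir.1x15-lvl.1-dexp.1'

# Hyphens separate keywords; each keyword becomes one module in the model.
keywords = modelspec_name.split('-')

for kw in keywords:
    # Assumption for illustration: the text before the first dot is the
    # keyword name and the remainder holds its options.
    name, _, options = kw.partition('.')
    print(f'{name:5s} options={options!r}')
```

The five resulting entries line up with the five modules listed in the fitting log: dlog, weight_channels, fir, levelshift and double_exponential.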

Generate the Modelspec

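Each item in the xform spec is a list: the dotted path of a function to call, followed by a dict of keyword arguments for that function. A step may also carry two further lists naming the context keys it reads and the keys it writes, as the standard_correlation step below does.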
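To make the idea concrete, here is a minimal sketch of how a spec of this shape could be evaluated. It is not the implementation of xforms.evaluate_step. The module toy_steps and the functions load_numbers and add_total are hypothetical stand-ins, and the sketch assumes each step receives the current context as keyword arguments and returns a dict of updates.

```python
import importlib
import sys
import types


def load_numbers(n, **context):
    """Hypothetical step: put a small data list into the context."""
    return {'data': list(range(n))}


def add_total(data, **context):
    """Hypothetical step: read 'data' from the context and add its sum."""
    return {'total': sum(data)}


# Register the toy functions under an importable module name so that they
# can be addressed by dotted path, the way the real spec addresses NEMS.
toy = types.ModuleType('toy_steps')
toy.load_numbers = load_numbers
toy.add_total = add_total
sys.modules['toy_steps'] = toy


def evaluate_step_sketch(step, context):
    """Resolve the dotted path, call it, and merge the result into context."""
    module_name, _, func_name = step[0].rpartition('.')
    func = getattr(importlib.import_module(module_name), func_name)
    updates = func(**context, **step[1])
    return {**context, **updates}


spec = [['toy_steps.load_numbers', {'n': 4}],
        ['toy_steps.add_total', {}]]

ctx = {}
for step in spec:
    ctx = evaluate_step_sketch(step, ctx)
print(ctx)
```

Each step only sees the context built up by the steps before it, which is why the order of the items in the spec matters.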

[6]:
xfspec = []
[7]:
# load from external format
xfspec.append(['nems.xforms.load_recording_wrapper',
               {'load_command': load_command,
                'exptid': expt_id,
                'datafile': str(datafile)
               }])
[8]:
# split the data into est and val
xfspec.append(['nems.xforms.split_by_occurrence_counts',
               {'epoch_regex': '^STIM_'}])
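The epoch_regex argument selects which epochs take part in the split. The sketch below illustrates one plausible reading of the step's name. It assumes that stimuli presented most often are held out for validation and the rest are used for estimation. That rule, the function split_epochs_by_count, and the epoch names are all assumptions for illustration, not taken from the NEMS source.

```python
import re
from collections import Counter


def split_epochs_by_count(epoch_names, epoch_regex):
    """Split matching epoch names into (est, val) by how often they occur."""
    pattern = re.compile(epoch_regex)
    counts = Counter(name for name in epoch_names if pattern.search(name))
    if not counts:
        return [], []
    top = max(counts.values())
    # Assumed rule: the most frequently repeated stimuli go to validation.
    val = sorted(name for name, c in counts.items() if c == top)
    est = sorted(name for name, c in counts.items() if c < top)
    return est, val


# Hypothetical epoch list: two stimuli shown twice, one shown ten times.
names = ['STIM_a'] * 2 + ['STIM_b'] * 10 + ['TRIAL'] * 12 + ['STIM_c'] * 2
est, val = split_epochs_by_count(names, '^STIM_')
print('est:', est)
print('val:', val)
```

Note that TRIAL occurs most often overall but is ignored, because it does not match the regular expression.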
[9]:
# average the response over repeated presentations of each stimulus
xfspec.append(['nems.xforms.average_away_stim_occurrences', {}])
[10]:
meta = {'cellid': cell_id, 'batch': batch, 'modelname': modelspec_name, 'recording': expt_id}

xfspec.append(['nems.xforms.init_from_keywords',
               {'keywordstring': modelspec_name,
                'meta': meta
               }])
[11]:
# init, then fit
xfspec.append(['nems.tf.cnnlink.fit_tf_init', {}])
xfspec.append(['nems.tf.cnnlink.fit_tf', {}])
[12]:
# generate predictions from the fitted model
xfspec.append(['nems.xforms.predict', {}])
[13]:
# test prediction then visualize
xfspec.append(['nems.analysis.api.standard_correlation', {},
               ['est', 'val', 'modelspec', 'rec'], ['modelspec']])
xfspec.append(['nems.xforms.plot_summary', {}])
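Because every step is a function path plus plain arguments, the spec itself is plain data (lists, strings and dicts). The sketch below uses only the standard library to round-trip a shortened spec through JSON. It illustrates why such a spec is easy to store and reload, but it is not the library's own save routine. The variable name xfspec_sketch is chosen here to avoid overwriting the notebook's xfspec.

```python
import json

# A shortened copy of the spec built above, using the same step format.
xfspec_sketch = [
    ['nems.xforms.split_by_occurrence_counts', {'epoch_regex': '^STIM_'}],
    ['nems.xforms.average_away_stim_occurrences', {}],
    ['nems.analysis.api.standard_correlation', {},
     ['est', 'val', 'modelspec', 'rec'], ['modelspec']],
]

# Serialize to text, then parse the text back into Python objects.
text = json.dumps(xfspec_sketch, indent=2)
restored = json.loads(text)

print(restored == xfspec_sketch)
```

The round trip is lossless here because the spec contains only lists, dicts and strings. Tuples or NumPy arrays among the arguments would need extra handling.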

Run the Analysis

[14]:
ctx = {}
for xfa in xfspec:
    ctx = xforms.evaluate_step(xfa, ctx)
[nems.xforms INFO] Evaluating: nems.xforms.load_recording_wrapper
[nems.xforms INFO] Loading cached file C:\Users\Alex\PycharmProjects\NEMS\recordings\TAR010c_afb264b3db970ec890e04c727e612c1cbfaced62.tgz
[nems.xforms INFO] Evaluating: nems.xforms.split_by_occurrence_counts
[nems.xforms INFO] Evaluating: nems.xforms.average_away_stim_occurrences
[nems.xforms INFO] Evaluating: nems.xforms.init_from_keywords
[nems.initializers INFO] kw: dlog
[nems.initializers INFO] kw: wc.18x1.g
[nems.initializers INFO] kw: fir.1x15
[nems.initializers INFO] kw: lvl.1
[nems.initializers INFO] kw: dexp.1
[nems.initializers INFO] Setting modelspec[0] input to stim
[nems.xforms INFO] Evaluating: nems.tf.cnnlink.fit_tf_init
[nems.tf.cnnlink INFO] target_module: ['levelshift', 'relu'] found at modelspec[3].
[nems.tf.cnnlink INFO] Mod 3 (nems.modules.levelshift.levelshift) fixing level to resp mean 0.207
[nems.tf.cnnlink INFO] resp has 1 channels
[nems.tf.cnnlink INFO] nems.modules.nonlinearity.dlog
nems.modules.weight_channels.gaussian
nems.modules.fir.basic
nems.modules.levelshift.levelshift
[nems.tf.cnnlink INFO] seed for this fit: 100
[nems.tf.cnnlink INFO] feat_dims: (90, 550, 18)
[nems.tf.cnnlink INFO] data_dims: (90, 550, 1)
[nems.tf.cnnlink INFO] rand seed for intialization: 100
[nems.modelspec INFO] Modelspec2tf: nems.modules.nonlinearity.dlog
[tensorflow WARNING] From C:\Users\Alex\Anaconda3\envs\nems-gpu\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py:1666: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
[nems.modelspec INFO] Modelspec2tf: nems.modules.weight_channels.gaussian
[nems.modelspec INFO] Modelspec2tf: nems.modules.fir.basic
[nems.modelspec INFO] Modelspec2tf: nems.modules.levelshift.levelshift
[nems.tf.cnn INFO] Initializing net: setting output, loss, optimizer, globals, tf session
[nems.tf.cnn INFO] Training with batch_size=90, LR=0.01, max_iter=2000, early_stopping_steps=5, early_stopping_tolerance=0.0005, optimizer=Adam.
[nems.tf.cnn INFO] Initial loss=0.7398484
[nems.tf.cnn INFO] 0000 loss=0.772475, delta=+0.000000
[nems.tf.cnn INFO] 0030 loss=0.632428, delta=-0.140048
[nems.tf.cnn INFO] 0060 loss=0.611400, delta=-0.021027
[nems.tf.cnn INFO] 0090 loss=0.607508, delta=-0.003892
[nems.tf.cnn INFO] 0120 loss=0.605566, delta=-0.001942
[nems.tf.cnn INFO] 0150 loss=0.604670, delta=-0.000896
[nems.tf.cnn INFO] 0180 loss=0.604249, delta=-0.000421
[nems.tf.cnn INFO] 0210 loss=0.604018, delta=-0.000231
[nems.tf.cnn INFO] 0240 loss=0.603880, delta=-0.000137
[nems.tf.cnn INFO] 0270 loss=0.603795, delta=-0.000085
[nems.tf.cnn INFO] 0300 loss=0.603741, delta=-0.000055
[nems.tf.cnn INFO] 5 epochs without significant improvement, stopping early!
[tensorflow INFO] Restoring parameters from C:\Users\Alex\PycharmProjects\NEMS\results\271\TAR010c-18-1\TAR010c.dlog_wc.18x1.g_fir.1x15_lvl.1_dexp.1.unknown_fitter.2020-05-28T203635\seed100-model.ckpt
[nems.tf.cnnlink INFO] tf2modelspec: nems.modules.nonlinearity.dlog
[nems.tf.cnnlink INFO] tf2modelspec: nems.modules.weight_channels.gaussian
[nems.tf.cnnlink INFO] tf2modelspec: nems.modules.fir.basic
[nems.tf.cnnlink INFO] tf2modelspec: nems.modules.levelshift.levelshift
[nems.tf.cnnlink INFO] starting eval_tf. evaluate nems model
[nems.tf.cnnlink INFO] saving nems pred
[nems.tf.cnnlink INFO] generating TF input matrix
[nems.tf.cnnlink INFO] feat_dims: [1, 49500, 18]
[nems.tf.cnnlink INFO] data_dims: [1, 49500, 1]
[nems.modelspec INFO] Modelspec2tf: nems.modules.nonlinearity.dlog
[nems.modelspec INFO] Modelspec2tf: nems.modules.weight_channels.gaussian
[nems.modelspec INFO] Modelspec2tf: nems.modules.fir.basic
[nems.modelspec INFO] Modelspec2tf: nems.modules.levelshift.levelshift
[nems.tf.cnn INFO] Initializing net: setting output, loss, optimizer, globals, tf session
[nems.tf.cnnlink INFO] Mean difference between NEMS and TF model pred: 5.000445e-08
[nems.initializers INFO] Found module 4 (double_exponential) for subset prefit
[nems.initializers INFO] Fit: [4]
[nems.initializers INFO] Freeze: [0 1 2 3]
[nems.initializers INFO] Exclude: []
[nems.initializers INFO] Freezing phi for module 0 (nems.modules.nonlinearity.dlog)
[nems.initializers INFO] Freezing phi for module 1 (nems.modules.weight_channels.gaussian)
[nems.initializers INFO] Freezing phi for module 2 (nems.modules.fir.basic)
[nems.initializers INFO] Freezing phi for module 3 (nems.modules.levelshift.levelshift)
[nems.analysis.fit_basic INFO] Data len pre-mask: 49500
[nems.analysis.fit_basic INFO] Data len post-mask: 49500
[nems.modelspec INFO] Freezing fast rec at start=4
[nems.fitters.fitter INFO] options {'ftol': 0.0001, 'maxiter': 700, 'maxfun': 7000}
[nems.fitters.fitter INFO] Start sigma: [0.6736 0.     1.3346 0.2074]
[nems.analysis.cost_functions INFO] Eval #100. E=0.863631
[nems.fitters.fitter INFO] Starting error: 0.909876 -- Final error: 0.863109
[nems.fitters.fitter INFO] Final sigma: [2.6557 0.0885 0.7428 0.8146]
[nems.analysis.fit_basic INFO] Delta error: 0.909876 - 0.863109 = -4.676691e-02
[nems.tf.cnnlink INFO] finished fit_tf_init, fit_idx=1/1
[nems.xforms INFO] Evaluating: nems.tf.cnnlink.fit_tf
[nems.tf.cnnlink INFO] seed for this fit: 50
[nems.tf.cnnlink INFO] feat_dims: (90, 550, 18)
[nems.tf.cnnlink INFO] data_dims: (90, 550, 1)
[nems.tf.cnnlink INFO] rand seed for intialization: 50
[nems.modelspec INFO] Modelspec2tf: nems.modules.nonlinearity.dlog
[nems.modelspec INFO] Modelspec2tf: nems.modules.weight_channels.gaussian
[nems.modelspec INFO] Modelspec2tf: nems.modules.fir.basic
[nems.modelspec INFO] Modelspec2tf: nems.modules.levelshift.levelshift
[nems.modelspec INFO] Modelspec2tf: nems.modules.nonlinearity.double_exponential
[nems.tf.cnn INFO] Initializing net: setting output, loss, optimizer, globals, tf session
[nems.tf.cnn INFO] Training with batch_size=90, LR=0.01, max_iter=1000, early_stopping_steps=5, early_stopping_tolerance=0.0005, optimizer=Adam.
[nems.tf.cnn INFO] Initial loss=0.5732868
[nems.tf.cnn INFO] 0000 loss=0.642390, delta=+0.000000
[nems.tf.cnn INFO] 0030 loss=0.569176, delta=-0.073214
[nems.tf.cnn INFO] 0060 loss=0.567734, delta=-0.001442
[nems.tf.cnn INFO] 0090 loss=0.566385, delta=-0.001349
[nems.tf.cnn INFO] 0120 loss=0.565158, delta=-0.001227
[nems.tf.cnn INFO] 0150 loss=0.564020, delta=-0.001138
[nems.tf.cnn INFO] 0180 loss=0.562967, delta=-0.001053
[nems.tf.cnn INFO] 0210 loss=0.562005, delta=-0.000962
[nems.tf.cnn INFO] 0240 loss=0.561137, delta=-0.000869
[nems.tf.cnn INFO] 0270 loss=0.560357, delta=-0.000780
[nems.tf.cnn INFO] 0300 loss=0.559658, delta=-0.000698
[nems.tf.cnn INFO] 0330 loss=0.559031, delta=-0.000627
[nems.tf.cnn INFO] 0360 loss=0.558464, delta=-0.000567
[nems.tf.cnn INFO] 0390 loss=0.557945, delta=-0.000519
[nems.tf.cnn INFO] 0420 loss=0.690692, delta=+0.132747
[nems.tf.cnn INFO] 0450 loss=0.568217, delta=+0.010271
[nems.tf.cnn INFO] 0480 loss=0.559294, delta=+0.001349
[nems.tf.cnn INFO] 0510 loss=0.558653, delta=+0.000708
[nems.tf.cnn INFO] 0540 loss=0.558198, delta=+0.000253
[nems.tf.cnn INFO] 0570 loss=0.557808, delta=-0.000137
[nems.tf.cnn INFO] 0600 loss=0.557458, delta=-0.000350
[nems.tf.cnn INFO] 0630 loss=0.557137, delta=-0.000321
[nems.tf.cnn INFO] 0660 loss=0.556841, delta=-0.000296
[nems.tf.cnn INFO] 5 epochs without significant improvement, stopping early!
[tensorflow INFO] Restoring parameters from C:\Users\Alex\PycharmProjects\NEMS\results\271\TAR010c-18-1\TAR010c.dlog_wc.18x1.g_fir.1x15_lvl.1_dexp.1.unknown_fitter.2020-05-28T203635\seed50-model.ckpt
[nems.tf.cnnlink INFO] tf2modelspec: nems.modules.nonlinearity.dlog
[nems.tf.cnnlink INFO] tf2modelspec: nems.modules.weight_channels.gaussian
[nems.tf.cnnlink INFO] tf2modelspec: nems.modules.fir.basic
[nems.tf.cnnlink INFO] tf2modelspec: nems.modules.levelshift.levelshift
[nems.tf.cnnlink INFO] tf2modelspec: nems.modules.nonlinearity.double_exponential
[nems.tf.cnnlink INFO] starting eval_tf. evaluate nems model
[nems.tf.cnnlink INFO] saving nems pred
[nems.tf.cnnlink INFO] generating TF input matrix
[nems.tf.cnnlink INFO] feat_dims: [1, 49500, 18]
[nems.tf.cnnlink INFO] data_dims: [1, 49500, 1]
[nems.modelspec INFO] Modelspec2tf: nems.modules.nonlinearity.dlog
[nems.modelspec INFO] Modelspec2tf: nems.modules.weight_channels.gaussian
[nems.modelspec INFO] Modelspec2tf: nems.modules.fir.basic
[nems.modelspec INFO] Modelspec2tf: nems.modules.levelshift.levelshift
[nems.modelspec INFO] Modelspec2tf: nems.modules.nonlinearity.double_exponential
[nems.tf.cnn INFO] Initializing net: setting output, loss, optimizer, globals, tf session
[nems.tf.cnnlink INFO] Mean difference between NEMS and TF model pred: 9.119456e-08
[nems.xforms INFO] Evaluating: nems.xforms.predict
[nems.xforms INFO] Evaluating: nems.analysis.api.standard_correlation
[nems.xforms INFO] Evaluating: nems.xforms.plot_summary
[nems.modelspec INFO] Quickplot: no epoch specified, falling back to "REFERENCE"
[nems.modelspec WARNING] Quickplot: no valid epochs matching REFERENCE. Will not subset data.
[nems.modelspec INFO] plotting row 1/6
[nems.modelspec INFO] plotting row 2/6
[nems.modelspec INFO] plotting row 3/6
[nems.modelspec INFO] plotting row 4/6
[nems.modelspec INFO] plotting row 5/6
[nems.modelspec INFO] plotting row 6/6
[nems.modelspec INFO] Quickplot: generated fig with title "Cell: TAR010c-18-1, Batch: 271, None #0 dlog-wc.18x1.g-fir.1x15-lvl.1-dexp.1"
bin range: 0-500
[Figure: summary plot titled "Cell: TAR010c-18-1, Batch: 271, None #0 dlog-wc.18x1.g-fir.1x15-lvl.1-dexp.1"]
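Both TensorFlow fits in the log above end with "5 epochs without significant improvement, stopping early!", with early_stopping_steps=5 and early_stopping_tolerance=0.0005. The sketch below shows one stopping rule that is consistent with the logged numbers: compare each reported loss with the best loss so far, and stop once the difference has stayed below the tolerance in magnitude for five reports in a row. This rule is inferred from the log only. The actual logic in nems.tf.cnn may differ, and the function early_stop_index is hypothetical.

```python
def early_stop_index(losses, tolerance=0.0005, patience=5):
    """Return the index of the report at which training stops, or None."""
    best = losses[0]
    quiet = 0  # consecutive reports with |delta| below the tolerance
    for i in range(1, len(losses)):
        delta = losses[i] - best
        quiet = quiet + 1 if abs(delta) < tolerance else 0
        best = min(best, losses[i])
        if quiet >= patience:
            return i
    return None


# Losses reported every 30 iterations during fit_tf_init (from the log).
first_fit = [0.772475, 0.632428, 0.611400, 0.607508, 0.605566, 0.604670,
             0.604249, 0.604018, 0.603880, 0.603795, 0.603741]

stop = early_stop_index(first_fit)
print(stop * 30)  # 300, the last iteration reported for the first fit
```

Applied to the second fit, the same rule resets its counter at the loss spike around iteration 420 and then stops at iteration 660, which also matches the log.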