aboutsummaryrefslogtreecommitdiffstats
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py112
1 files changed, 112 insertions, 0 deletions
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..84f9a4b
--- /dev/null
+++ b/train.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python
+
+import logging
+import numpy
+import sys
+import os
+import importlib
+
+import theano
+from theano import tensor
+
+from blocks.extensions import Printing, SimpleExtension, FinishAfter, ProgressBar
+from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
+from blocks.graph import ComputationGraph
+from blocks.main_loop import MainLoop
+from blocks.model import Model
+from blocks.algorithms import GradientDescent
+
+try:
+ from blocks.extras.extensions.plot import Plot
+ plot_avail = True
+except ImportError:
+ plot_avail = False
+ print "No plotting extension available."
+
+import data
+from paramsaveload import SaveLoadParams
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+sys.setrecursionlimit(500000)
+
+if __name__ == "__main__":
+ if len(sys.argv) != 2:
+ print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
+ sys.exit(1)
+ model_name = sys.argv[1]
+ config = importlib.import_module('.%s' % model_name, 'config')
+
+ # Build datastream
+ path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training")
+ valid_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/validation")
+ vocab_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt")
+
+ ds, train_stream = data.setup_datastream(path, vocab_path, config)
+ _, valid_stream = data.setup_datastream(valid_path, vocab_path, config)
+
+ dump_path = os.path.join("model_params", model_name+".pkl")
+
+ # Build model
+ m = config.Model(config, ds.vocab_size)
+
+ # Build the Blocks stuff for training
+ model = Model(m.sgd_cost)
+
+ algorithm = GradientDescent(cost=m.sgd_cost,
+ step_rule=config.step_rule,
+ parameters=model.parameters)
+
+ extensions = [
+ TrainingDataMonitoring(
+ [v for l in m.monitor_vars for v in l],
+ prefix='train',
+ every_n_batches=config.print_freq)
+ ]
+ if config.save_freq is not None and dump_path is not None:
+ extensions += [
+ SaveLoadParams(path=dump_path,
+ model=model,
+ before_training=True,
+ after_training=True,
+ after_epoch=True,
+ every_n_batches=config.save_freq)
+ ]
+ if valid_stream is not None and config.valid_freq != -1:
+ extensions += [
+ DataStreamMonitoring(
+ [v for l in m.monitor_vars_valid for v in l],
+ valid_stream,
+ prefix='valid',
+ every_n_batches=config.valid_freq),
+ ]
+ if plot_avail:
+ plot_channels = [['train_' + v.name for v in lt] + ['valid_' + v.name for v in lv]
+ for lt, lv in zip(m.monitor_vars, m.monitor_vars_valid)]
+ extensions += [
+ Plot(document='deepmind_qa_'+model_name,
+ channels=plot_channels,
+ # server_url='http://localhost:5006/', # If you need, change this
+ every_n_batches=config.print_freq)
+ ]
+ extensions += [
+ Printing(every_n_batches=config.print_freq,
+ after_epoch=True),
+ ProgressBar()
+ ]
+
+ main_loop = MainLoop(
+ model=model,
+ data_stream=train_stream,
+ algorithm=algorithm,
+ extensions=extensions
+ )
+
+ # Run the model !
+ main_loop.run()
+ main_loop.profile.report()
+
+
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :