aboutsummaryrefslogtreecommitdiffstats
path: root/train.py
blob: 84f9a4b900458976a69d478a0ed90cdd0685ea36 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112

#!/usr/bin/env python

import logging
import numpy
import sys
import os
import importlib

import theano
from theano import tensor

from blocks.extensions import Printing, SimpleExtension, FinishAfter, ProgressBar
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.model import Model
from blocks.algorithms import GradientDescent

try:
    from blocks.extras.extensions.plot import Plot
    plot_avail = True
except ImportError:
    plot_avail = False
    print "No plotting extension available."

import data
from paramsaveload import SaveLoadParams

# NOTE(review): the very large recursion limit is presumably needed for
# recursive traversal/pickling of deep Theano computation graphs — confirm.
sys.setrecursionlimit(500000)

# Show INFO-level (and more severe) messages from this script and the
# libraries it drives.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

if __name__ == "__main__":
    if len(sys.argv) != 2:
        print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
        sys.exit(1)
    model_name = sys.argv[1]
    config = importlib.import_module('.%s' % model_name, 'config')

    # Build datastream
    path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training")
    valid_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/validation")
    vocab_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt")

    ds, train_stream = data.setup_datastream(path, vocab_path, config)
    _, valid_stream = data.setup_datastream(valid_path, vocab_path, config)

    dump_path = os.path.join("model_params", model_name+".pkl")

    # Build model
    m = config.Model(config, ds.vocab_size)

    # Build the Blocks stuff for training
    model = Model(m.sgd_cost)

    algorithm = GradientDescent(cost=m.sgd_cost,
                                step_rule=config.step_rule,
                                parameters=model.parameters)

    extensions = [
            TrainingDataMonitoring(
                [v for l in m.monitor_vars for v in l],
                prefix='train',
                every_n_batches=config.print_freq)
    ]
    if config.save_freq is not None and dump_path is not None:
        extensions += [
            SaveLoadParams(path=dump_path,
                           model=model,
                           before_training=True,
                           after_training=True,
                           after_epoch=True,
                           every_n_batches=config.save_freq)
        ]
    if valid_stream is not None and config.valid_freq != -1:
        extensions += [
            DataStreamMonitoring(
                [v for l in m.monitor_vars_valid for v in l],
                valid_stream,
                prefix='valid',
                every_n_batches=config.valid_freq),
        ]
    if plot_avail:
        plot_channels = [['train_' + v.name for v in lt] + ['valid_' + v.name for v in lv]
                         for lt, lv in zip(m.monitor_vars, m.monitor_vars_valid)]
        extensions += [
            Plot(document='deepmind_qa_'+model_name,
                 channels=plot_channels,
                 # server_url='http://localhost:5006/', # If you need, change this
                 every_n_batches=config.print_freq)
        ]
    extensions += [
            Printing(every_n_batches=config.print_freq,
                     after_epoch=True),
            ProgressBar()
    ]

    main_loop = MainLoop(
        model=model,
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=extensions
    )

    # Run the model !
    main_loop.run()
    main_loop.profile.report()



#  vim: set sts=4 ts=4 sw=4 tw=0 et :