aboutsummaryrefslogtreecommitdiffstats
path: root/data.py
diff options
context:
space:
mode:
Diffstat (limited to 'data.py')
-rw-r--r--data.py177
1 files changed, 177 insertions, 0 deletions
diff --git a/data.py b/data.py
new file mode 100644
index 0000000..b3fa6d2
--- /dev/null
+++ b/data.py
@@ -0,0 +1,177 @@
+import logging
+import random
+import numpy
+
+import cPickle
+
+from picklable_itertools import iter_
+
+from fuel.datasets import Dataset
+from fuel.streams import DataStream
+from fuel.schemes import IterationScheme, ConstantScheme
+from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer
+
+import sys
+import os
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+class QADataset(Dataset):
+ def __init__(self, path, vocab_file, n_entities, need_sep_token, **kwargs):
+ self.provides_sources = ('context', 'question', 'answer', 'candidates')
+
+ self.path = path
+
+ self.vocab = ['@entity%d' % i for i in range(n_entities)] + \
+ [w.rstrip('\n') for w in open(vocab_file)] + \
+ ['<UNK>', '@placeholder'] + \
+ (['<SEP>'] if need_sep_token else [])
+
+ self.n_entities = n_entities
+ self.vocab_size = len(self.vocab)
+ self.reverse_vocab = {w: i for i, w in enumerate(self.vocab)}
+
+ super(QADataset, self).__init__(**kwargs)
+
+ def to_word_id(self, w, cand_mapping):
+ if w in cand_mapping:
+ return cand_mapping[w]
+ elif w[:7] == '@entity':
+ raise ValueError("Unmapped entity token: %s"%w)
+ elif w in self.reverse_vocab:
+ return self.reverse_vocab[w]
+ else:
+ return self.reverse_vocab['<UNK>']
+
+ def to_word_ids(self, s, cand_mapping):
+ return numpy.array([self.to_word_id(x, cand_mapping) for x in s.split(' ')], dtype=numpy.int32)
+
+ def get_data(self, state=None, request=None):
+ if request is None or state is not None:
+ raise ValueError("Expected a request (name of a question file) and no state.")
+
+ lines = [l.rstrip('\n') for l in open(os.path.join(self.path, request))]
+
+ ctx = lines[2]
+ q = lines[4]
+ a = lines[6]
+ cand = [s.split(':')[0] for s in lines[8:]]
+
+ entities = range(self.n_entities)
+ while len(cand) > len(entities):
+ logger.warning("Too many entities (%d) for question: %s, using duplicate entity identifiers"
+ %(len(cand), request))
+ entities = entities + entities
+ random.shuffle(entities)
+ cand_mapping = {t: k for t, k in zip(cand, entities)}
+
+ ctx = self.to_word_ids(ctx, cand_mapping)
+ q = self.to_word_ids(q, cand_mapping)
+ cand = numpy.array([self.to_word_id(x, cand_mapping) for x in cand], dtype=numpy.int32)
+ a = numpy.int32(self.to_word_id(a, cand_mapping))
+
+ if not a < self.n_entities:
+ raise ValueError("Invalid answer token %d"%a)
+ if not numpy.all(cand < self.n_entities):
+ raise ValueError("Invalid candidate in list %s"%repr(cand))
+ if not numpy.all(ctx < self.vocab_size):
+ raise ValueError("Context word id out of bounds: %d"%int(ctx.max()))
+ if not numpy.all(ctx >= 0):
+ raise ValueError("Context word id negative: %d"%int(ctx.min()))
+ if not numpy.all(q < self.vocab_size):
+ raise ValueError("Question word id out of bounds: %d"%int(q.max()))
+ if not numpy.all(q >= 0):
+ raise ValueError("Question word id negative: %d"%int(q.min()))
+
+ return (ctx, q, a, cand)
+
+class QAIterator(IterationScheme):
+ requests_examples = True
+ def __init__(self, path, shuffle=False, **kwargs):
+ self.path = path
+ self.shuffle = shuffle
+
+ super(QAIterator, self).__init__(**kwargs)
+
+ def get_request_iterator(self):
+ l = [f for f in os.listdir(self.path)
+ if os.path.isfile(os.path.join(self.path, f))]
+ if self.shuffle:
+ random.shuffle(l)
+ return iter_(l)
+
+# -------------- DATASTREAM SETUP --------------------
+
+
+class ConcatCtxAndQuestion(Transformer):
+ produces_examples = True
+ def __init__(self, stream, concat_question_before, separator_token=None, **kwargs):
+ assert stream.sources == ('context', 'question', 'answer', 'candidates')
+ self.sources = ('question', 'answer', 'candidates')
+
+ self.sep = numpy.array([separator_token] if separator_token is not None else [],
+ dtype=numpy.int32)
+ self.concat_question_before = concat_question_before
+
+ super(ConcatCtxAndQuestion, self).__init__(stream, **kwargs)
+
+ def get_data(self, request=None):
+ if request is not None:
+ raise ValueError('Unsupported: request')
+
+ ctx, q, a, cand = next(self.child_epoch_iterator)
+
+ if self.concat_question_before:
+ return (numpy.concatenate([q, self.sep, ctx]), a, cand)
+ else:
+ return (numpy.concatenate([ctx, self.sep, q]), a, cand)
+
+class _balanced_batch_helper(object):
+ def __init__(self, key):
+ self.key = key
+ def __call__(self, data):
+ return data[self.key].shape[0]
+
+def setup_datastream(path, vocab_file, config):
+ ds = QADataset(path, vocab_file, config.n_entities, need_sep_token=config.concat_ctx_and_question)
+ it = QAIterator(path, shuffle=config.shuffle_questions)
+
+ stream = DataStream(ds, iteration_scheme=it)
+
+ if config.concat_ctx_and_question:
+ stream = ConcatCtxAndQuestion(stream, config.concat_question_before, ds.reverse_vocab['<SEP>'])
+
+ # Sort sets of multiple batches to make batches of similar sizes
+ stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size * config.sort_batch_count))
+ comparison = _balanced_batch_helper(stream.sources.index('question' if config.concat_ctx_and_question else 'context'))
+ stream = Mapping(stream, SortMapping(comparison))
+ stream = Unpack(stream)
+
+ stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
+ stream = Padding(stream, mask_sources=['context', 'question', 'candidates'], mask_dtype='int32')
+
+ return ds, stream
+
+if __name__ == "__main__":
+ # Test
+ class DummyConfig:
+ def __init__(self):
+ self.shuffle_entities = True
+ self.shuffle_questions = False
+ self.concat_ctx_and_question = False
+ self.concat_question_before = False
+ self.batch_size = 2
+ self.sort_batch_count = 1000
+
+ ds, stream = setup_datastream(os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training"),
+ os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt"),
+ DummyConfig())
+ it = stream.get_epoch_iterator()
+
+ for i, d in enumerate(stream.get_epoch_iterator()):
+ print '--'
+ print d
+ if i > 2: break
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :