aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2015-12-21 10:34:16 +0100
committerAlex Auvolat <alex@adnab.me>2015-12-21 10:34:16 +0100
commit84f6dc7e109d822aa6bd5440b7a650c98ab69edb (patch)
tree3730c087dfc7a933f4a37ea39801c15e69ba811b
parentb543ac7f330fa7b9f1400d76d1720aba1e550d6c (diff)
downloadpgm-ctc-84f6dc7e109d822aa6bd5440b7a650c98ab69edb.tar.gz
pgm-ctc-84f6dc7e109d822aa6bd5440b7a650c98ab69edb.zip
minifix
-rw-r--r--ctc.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/ctc.py b/ctc.py
index cf629c1..a0dab80 100644
--- a/ctc.py
+++ b/ctc.py
@@ -34,7 +34,7 @@ class CTC(Brick):
l_blk = tensor.zeros((S, B))
l_blk = tensor.set_subtensor(l_blk[1::2,:],l)
- # dimension of alpha :
+ # dimension of alpha (corresponds to alpha hat in the paper) :
# T x B x S
# dimension of c :
# T x B
@@ -44,7 +44,8 @@ class CTC(Brick):
probs[0][tensor.arange(B), l[0]],
tensor.zeros((B, S-2))
], axis=1)
- c0 = alpha0.sum(axis=2)
+ c0 = alpha0.sum(axis=1)
+ alpha0 = alpha0 / c0[:,None]
# recursion
def recursion(p, p_mask, prev_alpha, prev_c):