# Source code for nltk.test.unit.test_hmm

# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
from nltk import hmm

def _wikipedia_example_hmm():
    """Build the two-state "umbrella" HMM shared by the tests below.

    The parameters are taken from the Wikipedia article on the
    forward-backward algorithm
    (http://en.wikipedia.org/wiki/Forward%E2%80%93backward_algorithm).

    :return: a 4-tuple ``(model, states, symbols, seq)`` where ``model`` is
        whatever ``hmm._create_hmm_tagger`` returns, ``states`` and
        ``symbols`` are the lists it was built from, and ``seq`` is a list
        of ``(symbol, None)`` pairs, one per observation.
    """
    hidden_states = ['rain', 'no rain']
    observables = ['umbrella', 'no umbrella']

    transition_probs = [[0.7, 0.3], [0.3, 0.7]]
    emission_probs = [[0.9, 0.1], [0.2, 0.8]]
    initial_probs = [0.5, 0.5]

    tagger = hmm._create_hmm_tagger(
        hidden_states,
        observables,
        transition_probs,
        emission_probs,
        initial_probs,
    )

    # Each observation is paired with None in the tag slot.
    observations = [
        'umbrella', 'umbrella', 'no umbrella', 'umbrella', 'umbrella',
    ]
    untagged = [(observation, None) for observation in observations]

    return tagger, hidden_states, observables, untagged


def test_forward_probability():
    """Check ``_forward_probability`` on the market HMM example.

    The model comes from ``hmm._market_hmm_example()`` and the expected
    values from the example on p. 385 of Huang et al.  The result of
    ``_forward_probability`` is raised as a power of 2 before comparison
    (presumably because the model works with base-2 log probabilities --
    confirm against the hmm module).
    """
    # numpy is imported locally so that importing this module does not
    # itself require numpy.
    from numpy.testing import assert_array_almost_equal

    # example from p. 385, Huang et al
    model, states, symbols = hmm._market_hmm_example()
    seq = [('up', None), ('up', None)]

    # One row per observation in ``seq``.
    # NOTE(review): column order is assumed to follow the model's state
    # order -- confirm against hmm._market_hmm_example.
    expected = [
        [0.09, 0.02, 0.35],
        [0.0357, 0.0085, 0.1792],
    ]

    fp = 2**model._forward_probability(seq)

    assert_array_almost_equal(fp, expected)
def test_forward_probability2():
    """Check ``_forward_probability`` against the Wikipedia umbrella example.

    The forward values are raised as a power of 2, normalized so that each
    row sums to 1 (the Wikipedia figures are normalized), and compared to
    the reference values to 4 decimal places.
    """
    # numpy is imported locally so that importing this module does not
    # itself require numpy.
    from numpy.testing import assert_array_almost_equal

    model, states, symbols, seq = _wikipedia_example_hmm()
    fp = 2**model._forward_probability(seq)

    # examples in wikipedia are normalized
    fp = (fp.T / fp.sum(axis=1)).T

    # results are swapped to match our states order
    # FIXME: is it possible to make order stable?
    wikipedia_results = [
        [0.1818, 0.8182],
        [0.1166, 0.8834],
        [0.8093, 0.1907],
        [0.2692, 0.7308],
        [0.1327, 0.8673],
    ]

    # Computed values go first (actual), reference values second (desired),
    # so a failure report labels the two arrays correctly.
    assert_array_almost_equal(fp, wikipedia_results, decimal=4)
def test_backward_probability():
    """Check ``_backward_probability`` against the Wikipedia umbrella example.

    The backward values are raised as a power of 2, normalized so that each
    row sums to 1 (the Wikipedia figures are normalized), and compared to
    the reference values to 4 decimal places.
    """
    # numpy is imported locally so that importing this module does not
    # itself require numpy.
    from numpy.testing import assert_array_almost_equal

    model, states, symbols, seq = _wikipedia_example_hmm()
    bp = 2**model._backward_probability(seq)

    # examples in wikipedia are normalized
    bp = (bp.T / bp.sum(axis=1)).T

    wikipedia_results = [
        # Forward-backward algorithm doesn't need b0_5,
        # so .backward_probability doesn't compute it.
        # [0.3531, 0.6469],
        [0.4077, 0.5923],
        [0.6237, 0.3763],
        [0.3467, 0.6533],
        [0.3727, 0.6273],
        [0.5, 0.5],
    ]

    # Computed values go first (actual), reference values second (desired),
    # so a failure report labels the two arrays correctly.
    assert_array_almost_equal(bp, wikipedia_results, decimal=4)
def setup_module(module):
    """Skip this whole test module when numpy is not installed.

    :param module: the module object handed in by the test runner; unused.
    :raises SkipTest: (from nose) if ``import numpy`` fails.
    """
    try:
        import numpy  # noqa: F401 -- imported only to probe availability
    except ImportError:
        # Import the skip exception only when there is something to skip,
        # so a runner without nose can still run these tests when numpy
        # is available.
        from nose import SkipTest

        raise SkipTest("numpy is required for nltk.test.test_hmm")