# Source code for nltk.test.unit.test_classify

# -*- coding: utf-8 -*-
"""
Unit tests for nltk.classify. See also: nltk/test/classify.doctest
"""
from __future__ import absolute_import
from nose import SkipTest
from nltk import classify

# Labelled training data: each item is (featureset, label), where every
# featureset has binary features 'a', 'b', 'c' and the label is 'x' or 'y'.
# Some featuresets repeat with conflicting labels on purpose.
TRAIN = [
    ({'a': 1, 'b': 1, 'c': 1}, 'y'),
    ({'a': 1, 'b': 1, 'c': 1}, 'x'),
    ({'a': 1, 'b': 1, 'c': 0}, 'y'),
    ({'a': 0, 'b': 1, 'c': 1}, 'x'),
    ({'a': 0, 'b': 1, 'c': 1}, 'y'),
    ({'a': 0, 'b': 0, 'c': 1}, 'y'),
    ({'a': 0, 'b': 1, 'c': 0}, 'x'),
    ({'a': 0, 'b': 0, 'c': 0}, 'x'),
    ({'a': 0, 'b': 1, 'c': 1}, 'y'),
]

# Unlabelled featuresets to classify; each is a plain dict (not a tuple).
# Paired positionally with RESULTS by assert_classifier_correct.
TEST = [
    {'a': 1, 'b': 0, 'c': 1},  # never appears in TRAIN
    {'a': 1, 'b': 0, 'c': 0},  # never appears in TRAIN
    {'a': 0, 'b': 1, 'c': 1},  # appears 3 times in TRAIN: labels x, y, y
    {'a': 0, 'b': 1, 'c': 0},  # appears once in TRAIN: label x
]

# Expected (P(label='x'), P(label='y')) for each featureset in TEST,
# in the same order; compared with a tolerance of 1e-2.
RESULTS = [
    (0.16, 0.84),
    (0.46, 0.54),
    (0.41, 0.59),
    (0.76, 0.24),
]

def assert_classifier_correct(algorithm):
    """Train a MaxentClassifier on TRAIN and check its probabilities on TEST.

    For each featureset in TEST, the predicted probabilities of labels 'x'
    and 'y' must be within 1e-2 of the matching (px, py) pair in RESULTS.

    :param algorithm: training algorithm name passed through to
        ``classify.MaxentClassifier.train`` (e.g. ``'MEGAM'``, ``'TADM'``).
    :raises SkipTest: if training raises ``LookupError`` or
        ``AttributeError`` -- presumably because the external tool for
        ``algorithm`` is unavailable (TODO confirm).
    :raises AssertionError: if RESULTS and TEST differ in length, or if any
        predicted probability falls outside the tolerance.
    """
    # zip() would silently drop unmatched entries, hiding fixture mistakes.
    assert len(RESULTS) == len(TEST), (len(RESULTS), len(TEST))

    try:
        classifier = classify.MaxentClassifier.train(
            TRAIN, algorithm, trace=0, max_iter=1000
        )
    except (LookupError, AttributeError) as e:
        # Treat a training setup failure as "cannot test here", not a failure.
        raise SkipTest(str(e))

    for (px, py), featureset in zip(RESULTS, TEST):
        pdist = classifier.prob_classify(featureset)
        assert abs(pdist.prob('x') - px) < 1e-2, (pdist.prob('x'), px)
        assert abs(pdist.prob('y') - py) < 1e-2, (pdist.prob('y'), py)
def test_megam():
    """Check MaxentClassifier trained with the 'MEGAM' algorithm.

    Skipped (via assert_classifier_correct) when training raises
    LookupError or AttributeError.
    """
    assert_classifier_correct('MEGAM')
def test_tadm():
    """Check MaxentClassifier trained with the 'TADM' algorithm.

    Skipped (via assert_classifier_correct) when training raises
    LookupError or AttributeError.
    """
    assert_classifier_correct('TADM')