cc.mallet.topics
Class LDAHyper

java.lang.Object
  extended by cc.mallet.topics.LDAHyper
All Implemented Interfaces:
java.io.Serializable
Direct Known Subclasses:
DMRTopicModel, LDAStream

public class LDAHyper
extends java.lang.Object
implements java.io.Serializable

Latent Dirichlet Allocation (LDA) with optimized hyperparameters.

Author:
David Mimno, Andrew McCallum
See Also:
Serialized Form

Nested Class Summary
 class LDAHyper.Topication
           
 
Field Summary
protected  double[] alpha
           
protected  Alphabet alphabet
           
protected  double alphaSum
           
protected  double beta
           
protected  double betaSum
           
 int burninPeriod
           
protected  double[] cachedCoefficients
           
protected  java.util.ArrayList<LDAHyper.Topication> data
           
static double DEFAULT_BETA
           
protected  int[] docLengthCounts
           
protected  java.text.NumberFormat formatter
           
 int iterationsSoFar
           
 int numIterations
           
protected  int numTopics
           
protected  int numTypes
           
protected  int[] oneDocTopicCounts
           
 int optimizeInterval
           
protected  java.lang.String outputModelFilename
           
protected  int outputModelInterval
           
protected  boolean printLogLikelihood
           
protected  Randoms random
           
 int saveSampleInterval
           
protected  int saveStateInterval
           
 int showTopicsInterval
           
protected  double smoothingOnlyMass
           
protected  java.lang.String stateFilename
           
protected  InstanceList testing
           
protected  int[] tokensPerTopic
           
protected  LabelAlphabet topicAlphabet
           
protected  int[][] topicDocCounts
           
protected  gnu.trove.TIntIntHashMap[] typeTopicCounts
           
 int wordsPerTopic
           
 
Constructor Summary
LDAHyper(int numberOfTopics)
           
LDAHyper(int numberOfTopics, double alphaSum, double beta)
           
LDAHyper(int numberOfTopics, double alphaSum, double beta, Randoms random)
           
LDAHyper(LabelAlphabet topicAlphabet, double alphaSum, double beta, Randoms random)
           
 
Method Summary
 void addInstances(InstanceList training)
           
 void addInstances(InstanceList training, java.util.List<LabelSequence> topics)
           
 double empiricalLikelihood(int numSamples, InstanceList testing)
           
 void estimate()
           
 void estimate(int iterationsThisRound)
           
 Alphabet getAlphabet()
           
 int getCountFeatureTopic(int featureIndex, int topicIndex)
           
 int getCountTokensPerTopic(int topicIndex)
           
 java.util.ArrayList<LDAHyper.Topication> getData()
           
 int getNumTopics()
           
 IDSorter[] getSortedTopicWords(int topic)
           
 LabelAlphabet getTopicAlphabet()
           
protected  void initializeHistogramsAndCachedValues()
          Gather statistics on the size of documents and create histograms for use in Dirichlet hyperparameter optimization.
protected  int instanceLength(Instance instance)
           
static void main(java.lang.String[] args)
           
 double modelLogLikelihood()
           
 void printDocumentTopics(java.io.File f)
           
 void printDocumentTopics(java.io.PrintWriter pw)
           
 void printDocumentTopics(java.io.PrintWriter pw, double threshold, int max)
           
 void printState(java.io.File f)
           
 void printState(java.io.PrintStream out)
           
 void printTopWords(java.io.File file, int numWords, boolean useNewLines)
           
 void printTopWords(java.io.PrintStream out, int numWords, boolean usingNewLines)
           
static LDAHyper read(java.io.File f)
           
protected  void sampleTopicsForOneDoc(FeatureSequence tokenSequence, FeatureSequence topicSequence, boolean shouldSaveState, boolean readjustTopicsAndStats)
           
 void setBurninPeriod(int burninPeriod)
           
 void setModelOutput(int interval, java.lang.String filename)
           
 void setNumIterations(int numIterations)
           
 void setOptimizeInterval(int interval)
           
 void setRandomSeed(int seed)
           
 void setSaveState(int interval, java.lang.String filename)
          Define how often and where to save the state
 void setTestingInstances(InstanceList testing)
          Held-out instances for empirical likelihood calculation
 void setTopicDisplay(int interval, int n)
           
 double topicLabelMutualInformation()
           
 void topicXMLReport(java.io.PrintWriter out, int numWords)
           
 void write(java.io.File f)
           
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

data

protected java.util.ArrayList<LDAHyper.Topication> data

alphabet

protected Alphabet alphabet

topicAlphabet

protected LabelAlphabet topicAlphabet

numTopics

protected int numTopics

numTypes

protected int numTypes

alpha

protected double[] alpha

alphaSum

protected double alphaSum

beta

protected double beta

betaSum

protected double betaSum

DEFAULT_BETA

public static final double DEFAULT_BETA
See Also:
Constant Field Values

smoothingOnlyMass

protected double smoothingOnlyMass

cachedCoefficients

protected double[] cachedCoefficients

testing

protected InstanceList testing

oneDocTopicCounts

protected int[] oneDocTopicCounts

typeTopicCounts

protected gnu.trove.TIntIntHashMap[] typeTopicCounts

tokensPerTopic

protected int[] tokensPerTopic

docLengthCounts

protected int[] docLengthCounts

topicDocCounts

protected int[][] topicDocCounts

iterationsSoFar

public int iterationsSoFar

numIterations

public int numIterations

burninPeriod

public int burninPeriod

saveSampleInterval

public int saveSampleInterval

optimizeInterval

public int optimizeInterval

showTopicsInterval

public int showTopicsInterval

wordsPerTopic

public int wordsPerTopic

outputModelInterval

protected int outputModelInterval

outputModelFilename

protected java.lang.String outputModelFilename

saveStateInterval

protected int saveStateInterval

stateFilename

protected java.lang.String stateFilename

random

protected Randoms random

formatter

protected java.text.NumberFormat formatter

printLogLikelihood

protected boolean printLogLikelihood
Constructor Detail

LDAHyper

public LDAHyper(int numberOfTopics)

LDAHyper

public LDAHyper(int numberOfTopics,
                double alphaSum,
                double beta)

LDAHyper

public LDAHyper(int numberOfTopics,
                double alphaSum,
                double beta,
                Randoms random)

LDAHyper

public LDAHyper(LabelAlphabet topicAlphabet,
                double alphaSum,
                double beta,
                Randoms random)
Method Detail

getAlphabet

public Alphabet getAlphabet()

getTopicAlphabet

public LabelAlphabet getTopicAlphabet()

getNumTopics

public int getNumTopics()

getData

public java.util.ArrayList<LDAHyper.Topication> getData()

getCountFeatureTopic

public int getCountFeatureTopic(int featureIndex,
                                int topicIndex)

getCountTokensPerTopic

public int getCountTokensPerTopic(int topicIndex)

setTestingInstances

public void setTestingInstances(InstanceList testing)
Held-out instances for empirical likelihood calculation


setNumIterations

public void setNumIterations(int numIterations)

setBurninPeriod

public void setBurninPeriod(int burninPeriod)

setTopicDisplay

public void setTopicDisplay(int interval,
                            int n)

setRandomSeed

public void setRandomSeed(int seed)

setOptimizeInterval

public void setOptimizeInterval(int interval)

setModelOutput

public void setModelOutput(int interval,
                           java.lang.String filename)

setSaveState

public void setSaveState(int interval,
                         java.lang.String filename)
Define how often and where to save the state

Parameters:
interval - Save a copy of the state once every `interval` iterations.
filename - Save the state to this file, with the iteration number appended as a suffix.

instanceLength

protected int instanceLength(Instance instance)

addInstances

public void addInstances(InstanceList training)

addInstances

public void addInstances(InstanceList training,
                         java.util.List<LabelSequence> topics)

initializeHistogramsAndCachedValues

protected void initializeHistogramsAndCachedValues()
Gather statistics on the size of documents and create histograms for use in Dirichlet hyperparameter optimization.


estimate

public void estimate()
              throws java.io.IOException
Throws:
java.io.IOException

estimate

public void estimate(int iterationsThisRound)
              throws java.io.IOException
Throws:
java.io.IOException

sampleTopicsForOneDoc

protected void sampleTopicsForOneDoc(FeatureSequence tokenSequence,
                                     FeatureSequence topicSequence,
                                     boolean shouldSaveState,
                                     boolean readjustTopicsAndStats)

getSortedTopicWords

public IDSorter[] getSortedTopicWords(int topic)

printTopWords

public void printTopWords(java.io.File file,
                          int numWords,
                          boolean useNewLines)
                   throws java.io.IOException
Throws:
java.io.IOException

printTopWords

public void printTopWords(java.io.PrintStream out,
                          int numWords,
                          boolean usingNewLines)

topicXMLReport

public void topicXMLReport(java.io.PrintWriter out,
                           int numWords)

printDocumentTopics

public void printDocumentTopics(java.io.File f)
                         throws java.io.IOException
Throws:
java.io.IOException

printDocumentTopics

public void printDocumentTopics(java.io.PrintWriter pw)

printDocumentTopics

public void printDocumentTopics(java.io.PrintWriter pw,
                                double threshold,
                                int max)
Parameters:
pw - The print writer to write to.
threshold - Only print topics whose proportion is greater than this value.
max - Print no more than this many topics.

printState

public void printState(java.io.File f)
                throws java.io.IOException
Throws:
java.io.IOException

printState

public void printState(java.io.PrintStream out)

write

public void write(java.io.File f)

read

public static LDAHyper read(java.io.File f)

topicLabelMutualInformation

public double topicLabelMutualInformation()

empiricalLikelihood

public double empiricalLikelihood(int numSamples,
                                  InstanceList testing)

modelLogLikelihood

public double modelLogLikelihood()

main

public static void main(java.lang.String[] args)
                 throws java.io.IOException
Throws:
java.io.IOException