cc.mallet.topics
Class ParallelTopicModel

java.lang.Object
  extended by cc.mallet.topics.ParallelTopicModel
All Implemented Interfaces:
java.io.Serializable

public class ParallelTopicModel
extends java.lang.Object
implements java.io.Serializable

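Simple parallel threaded implementation of LDA, following the UCI NIPS paper (Newman, Asuncion, Smyth and Welling, "Distributed Inference for Latent Dirichlet Allocation", NIPS 2007), with the SparseLDA sampling scheme and data structure (Yao, Mimno and McCallum, "Efficient Methods for Topic Model Inference on Streaming Document Collections", KDD 2009).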

Author:
David Mimno, Andrew McCallum
See Also:
Serialized Form

Field Summary
protected  double[] alpha
           
protected  Alphabet alphabet
           
protected  double alphaSum
           
protected  double beta
           
protected  double betaSum
           
 int burninPeriod
           
protected  java.util.ArrayList<TopicAssignment> data
           
static double DEFAULT_BETA
           
protected  int[] docLengthCounts
           
protected  java.text.NumberFormat formatter
           
protected  java.lang.String modelFilename
           
 int numIterations
           
protected  int numTopics
           
protected  int numTypes
           
 int optimizeInterval
           
protected  boolean printLogLikelihood
           
protected  int randomSeed
           
protected  int saveModelInterval
           
 int saveSampleInterval
           
protected  int saveStateInterval
           
 int showTopicsInterval
           
protected  java.lang.String stateFilename
           
protected  int[] tokensPerTopic
           
protected  LabelAlphabet topicAlphabet
           
protected  int topicBits
           
protected  int[][] topicDocCounts
           
protected  int topicMask
           
protected  int[][] typeTopicCounts
           
 int wordsPerTopic
           
 
Constructor Summary
ParallelTopicModel(int numberOfTopics)
           
ParallelTopicModel(int numberOfTopics, double alphaSum, double beta)
           
ParallelTopicModel(LabelAlphabet topicAlphabet, double alphaSum, double beta)
           
 
Method Summary
 void addInstances(InstanceList training)
           
 void buildInitialTypeTopicCounts()
           
 void estimate()
           
 Alphabet getAlphabet()
           
 java.util.ArrayList<TopicAssignment> getData()
           
 TopicInferencer getInferencer()
           
 int getNumTopics()
           
 java.util.TreeSet[] getSortedWords()
          Return an array of sorted sets (one set per topic).
 LabelAlphabet getTopicAlphabet()
           
 java.lang.Object[][] getTopWords(int numWords)
          Return an array (one element for each topic) of arrays of words, which are the most probable words for that topic in descending order.
static void main(java.lang.String[] args)
           
 double modelLogLikelihood()
           
 void optimizeAlpha(WorkerRunnable[] runnables)
           
 void printDocumentTopics(java.io.File file)
           
 void printDocumentTopics(java.io.PrintWriter out)
           
 void printDocumentTopics(java.io.PrintWriter out, double threshold, int max)
           
 void printState(java.io.File f)
           
 void printState(java.io.PrintStream out)
           
 void printTopicWordWeights(java.io.File file)
           
 void printTopicWordWeights(java.io.PrintWriter out)
          Print an unnormalized weight for every word in every topic.
 void printTopWords(java.io.File file, int numWords, boolean useNewLines)
           
 void printTopWords(java.io.PrintStream out, int numWords, boolean usingNewLines)
           
 void printTypeTopicCounts(java.io.File file)
          Write the internal representation of type-topic counts (count/topic pairs in descending order by count) to a file.
static ParallelTopicModel read(java.io.File f)
           
 void setBurninPeriod(int burninPeriod)
           
 void setNumIterations(int numIterations)
           
 void setNumThreads(int threads)
           
 void setOptimizeInterval(int interval)
          Sets the interval for optimizing the Dirichlet hyperparameters.
 void setRandomSeed(int seed)
           
 void setSaveSerializedModel(int interval, java.lang.String filename)
          Define how often and where to save a serialized model.
 void setSaveState(int interval, java.lang.String filename)
          Define how often and where to save a text representation of the current state.
 void setTopicDisplay(int interval, int n)
           
 void sumTypeTopicCounts(WorkerRunnable[] runnables)
           
 void write(java.io.File serializedModelFile)
           
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

data

protected java.util.ArrayList<TopicAssignment> data

alphabet

protected Alphabet alphabet

topicAlphabet

protected LabelAlphabet topicAlphabet

numTopics

protected int numTopics

topicMask

protected int topicMask

topicBits

protected int topicBits

numTypes

protected int numTypes

alpha

protected double[] alpha

alphaSum

protected double alphaSum

beta

protected double beta

betaSum

protected double betaSum

DEFAULT_BETA

public static final double DEFAULT_BETA
See Also:
Constant Field Values

typeTopicCounts

protected int[][] typeTopicCounts

tokensPerTopic

protected int[] tokensPerTopic

docLengthCounts

protected int[] docLengthCounts

topicDocCounts

protected int[][] topicDocCounts

numIterations

public int numIterations

burninPeriod

public int burninPeriod

saveSampleInterval

public int saveSampleInterval

optimizeInterval

public int optimizeInterval

showTopicsInterval

public int showTopicsInterval

wordsPerTopic

public int wordsPerTopic

saveStateInterval

protected int saveStateInterval

stateFilename

protected java.lang.String stateFilename

saveModelInterval

protected int saveModelInterval

modelFilename

protected java.lang.String modelFilename

randomSeed

protected int randomSeed

formatter

protected java.text.NumberFormat formatter

printLogLikelihood

protected boolean printLogLikelihood
Constructor Detail

ParallelTopicModel

public ParallelTopicModel(int numberOfTopics)

ParallelTopicModel

public ParallelTopicModel(int numberOfTopics,
                          double alphaSum,
                          double beta)

ParallelTopicModel

public ParallelTopicModel(LabelAlphabet topicAlphabet,
                          double alphaSum,
                          double beta)
Method Detail

getAlphabet

public Alphabet getAlphabet()

getTopicAlphabet

public LabelAlphabet getTopicAlphabet()

getNumTopics

public int getNumTopics()

getData

public java.util.ArrayList<TopicAssignment> getData()

setNumIterations

public void setNumIterations(int numIterations)

setBurninPeriod

public void setBurninPeriod(int burninPeriod)

setTopicDisplay

public void setTopicDisplay(int interval,
                            int n)

setRandomSeed

public void setRandomSeed(int seed)

setOptimizeInterval

public void setOptimizeInterval(int interval)
Sets the interval for optimizing the Dirichlet hyperparameters.


setNumThreads

public void setNumThreads(int threads)

setSaveState

public void setSaveState(int interval,
                         java.lang.String filename)
Define how often and where to save a text representation of the current state. Files are GZipped.

Parameters:
interval - Save a copy of the state every interval iterations.
filename - Save the state to this file, with the iteration number as a suffix.

setSaveSerializedModel

public void setSaveSerializedModel(int interval,
                                   java.lang.String filename)
Define how often and where to save a serialized model.

Parameters:
interval - Save a serialized model every interval iterations.
filename - Save the serialized model to this file, with the iteration number as a suffix.

addInstances

public void addInstances(InstanceList training)

buildInitialTypeTopicCounts

public void buildInitialTypeTopicCounts()

sumTypeTopicCounts

public void sumTypeTopicCounts(WorkerRunnable[] runnables)

optimizeAlpha

public void optimizeAlpha(WorkerRunnable[] runnables)

estimate

public void estimate()
              throws java.io.IOException
Throws:
java.io.IOException

printTopWords

public void printTopWords(java.io.File file,
                          int numWords,
                          boolean useNewLines)
                   throws java.io.IOException
Throws:
java.io.IOException

getSortedWords

public java.util.TreeSet[] getSortedWords()
Return an array of sorted sets (one set per topic). Each set contains IDSorter objects with integer keys into the alphabet. To get direct access to the Strings, use getTopWords().


getTopWords

public java.lang.Object[][] getTopWords(int numWords)
Return an array (one element for each topic) of arrays of words, which are the most probable words for that topic in descending order. The words are returned as Objects, but in practice will usually be Strings.

Parameters:
numWords - The maximum length of each topic's array of words (may be less).

printTopWords

public void printTopWords(java.io.PrintStream out,
                          int numWords,
                          boolean usingNewLines)

printTypeTopicCounts

public void printTypeTopicCounts(java.io.File file)
                          throws java.io.IOException
Write the internal representation of type-topic counts (count/topic pairs in descending order by count) to a file.

Throws:
java.io.IOException

printTopicWordWeights

public void printTopicWordWeights(java.io.File file)
                           throws java.io.IOException
Throws:
java.io.IOException

printTopicWordWeights

public void printTopicWordWeights(java.io.PrintWriter out)
                           throws java.io.IOException
Print an unnormalized weight for every word in every topic. Most of these will be equal to the smoothing parameter beta.

Throws:
java.io.IOException

printDocumentTopics

public void printDocumentTopics(java.io.File file)
                         throws java.io.IOException
Throws:
java.io.IOException

printDocumentTopics

public void printDocumentTopics(java.io.PrintWriter out)

printDocumentTopics

public void printDocumentTopics(java.io.PrintWriter out,
                                double threshold,
                                int max)
Parameters:
out - The print writer to write to.
threshold - Only print topics with proportion greater than this number.
max - Print no more than this many topics.

printState

public void printState(java.io.File f)
                throws java.io.IOException
Throws:
java.io.IOException

printState

public void printState(java.io.PrintStream out)

modelLogLikelihood

public double modelLogLikelihood()

getInferencer

public TopicInferencer getInferencer()

write

public void write(java.io.File serializedModelFile)

read

public static ParallelTopicModel read(java.io.File f)
                               throws java.lang.Exception
Throws:
java.lang.Exception

main

public static void main(java.lang.String[] args)