cc.mallet.topics
Class LDAHyper
java.lang.Object
cc.mallet.topics.LDAHyper
- All Implemented Interfaces:
- java.io.Serializable
- Direct Known Subclasses:
- DMRTopicModel, LDAStream
public class LDAHyper
- extends java.lang.Object
- implements java.io.Serializable
Latent Dirichlet Allocation with optimized hyperparameters
- Author:
- David Mimno, Andrew McCallum
- See Also:
- Serialized Form
Method Summary |
void |
addInstances(InstanceList training)
|
void |
addInstances(InstanceList training,
java.util.List<LabelSequence> topics)
|
double |
empiricalLikelihood(int numSamples,
InstanceList testing)
|
void |
estimate()
|
void |
estimate(int iterationsThisRound)
|
Alphabet |
getAlphabet()
|
int |
getCountFeatureTopic(int featureIndex,
int topicIndex)
|
int |
getCountTokensPerTopic(int topicIndex)
|
java.util.ArrayList<LDAHyper.Topication> |
getData()
|
int |
getNumTopics()
|
IDSorter[] |
getSortedTopicWords(int topic)
|
LabelAlphabet |
getTopicAlphabet()
|
protected void |
initializeHistogramsAndCachedValues()
Gather statistics on the size of documents
and create histograms for use in Dirichlet hyperparameter
optimization. |
protected int |
instanceLength(Instance instance)
|
static void |
main(java.lang.String[] args)
|
double |
modelLogLikelihood()
|
void |
printDocumentTopics(java.io.File f)
|
void |
printDocumentTopics(java.io.PrintWriter pw)
|
void |
printDocumentTopics(java.io.PrintWriter pw,
double threshold,
int max)
|
void |
printState(java.io.File f)
|
void |
printState(java.io.PrintStream out)
|
void |
printTopWords(java.io.File file,
int numWords,
boolean useNewLines)
|
void |
printTopWords(java.io.PrintStream out,
int numWords,
boolean usingNewLines)
|
static LDAHyper |
read(java.io.File f)
|
protected void |
sampleTopicsForOneDoc(FeatureSequence tokenSequence,
FeatureSequence topicSequence,
boolean shouldSaveState,
boolean readjustTopicsAndStats)
|
void |
setBurninPeriod(int burninPeriod)
|
void |
setModelOutput(int interval,
java.lang.String filename)
|
void |
setNumIterations(int numIterations)
|
void |
setOptimizeInterval(int interval)
|
void |
setRandomSeed(int seed)
|
void |
setSaveState(int interval,
java.lang.String filename)
Define how often and where to save the state |
void |
setTestingInstances(InstanceList testing)
Held-out instances for empirical likelihood calculation |
void |
setTopicDisplay(int interval,
int n)
|
double |
topicLabelMutualInformation()
|
void |
topicXMLReport(java.io.PrintWriter out,
int numWords)
|
void |
write(java.io.File f)
|
Methods inherited from class java.lang.Object |
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
data
protected java.util.ArrayList<LDAHyper.Topication> data
alphabet
protected Alphabet alphabet
topicAlphabet
protected LabelAlphabet topicAlphabet
numTopics
protected int numTopics
numTypes
protected int numTypes
alpha
protected double[] alpha
alphaSum
protected double alphaSum
beta
protected double beta
betaSum
protected double betaSum
DEFAULT_BETA
public static final double DEFAULT_BETA
- See Also:
- Constant Field Values
smoothingOnlyMass
protected double smoothingOnlyMass
cachedCoefficients
protected double[] cachedCoefficients
testing
protected InstanceList testing
oneDocTopicCounts
protected int[] oneDocTopicCounts
typeTopicCounts
protected gnu.trove.TIntIntHashMap[] typeTopicCounts
tokensPerTopic
protected int[] tokensPerTopic
docLengthCounts
protected int[] docLengthCounts
topicDocCounts
protected int[][] topicDocCounts
iterationsSoFar
public int iterationsSoFar
numIterations
public int numIterations
burninPeriod
public int burninPeriod
saveSampleInterval
public int saveSampleInterval
optimizeInterval
public int optimizeInterval
showTopicsInterval
public int showTopicsInterval
wordsPerTopic
public int wordsPerTopic
outputModelInterval
protected int outputModelInterval
outputModelFilename
protected java.lang.String outputModelFilename
saveStateInterval
protected int saveStateInterval
stateFilename
protected java.lang.String stateFilename
random
protected Randoms random
formatter
protected java.text.NumberFormat formatter
printLogLikelihood
protected boolean printLogLikelihood
LDAHyper
public LDAHyper(int numberOfTopics)
LDAHyper
public LDAHyper(int numberOfTopics,
double alphaSum,
double beta)
LDAHyper
public LDAHyper(int numberOfTopics,
double alphaSum,
double beta,
Randoms random)
LDAHyper
public LDAHyper(LabelAlphabet topicAlphabet,
double alphaSum,
double beta,
Randoms random)
getAlphabet
public Alphabet getAlphabet()
getTopicAlphabet
public LabelAlphabet getTopicAlphabet()
getNumTopics
public int getNumTopics()
getData
public java.util.ArrayList<LDAHyper.Topication> getData()
getCountFeatureTopic
public int getCountFeatureTopic(int featureIndex,
int topicIndex)
getCountTokensPerTopic
public int getCountTokensPerTopic(int topicIndex)
setTestingInstances
public void setTestingInstances(InstanceList testing)
- Held-out instances for empirical likelihood calculation
setNumIterations
public void setNumIterations(int numIterations)
setBurninPeriod
public void setBurninPeriod(int burninPeriod)
setTopicDisplay
public void setTopicDisplay(int interval,
int n)
setRandomSeed
public void setRandomSeed(int seed)
setOptimizeInterval
public void setOptimizeInterval(int interval)
setModelOutput
public void setModelOutput(int interval,
java.lang.String filename)
setSaveState
public void setSaveState(int interval,
java.lang.String filename)
- Define how often and where to save the state
- Parameters:
interval
- Save a copy of the state every interval iterations.
filename
- Save the state to this file, with the iteration number as a suffix
instanceLength
protected int instanceLength(Instance instance)
addInstances
public void addInstances(InstanceList training)
addInstances
public void addInstances(InstanceList training,
java.util.List<LabelSequence> topics)
initializeHistogramsAndCachedValues
protected void initializeHistogramsAndCachedValues()
- Gather statistics on the size of documents
and create histograms for use in Dirichlet hyperparameter
optimization.
estimate
public void estimate()
throws java.io.IOException
- Throws:
java.io.IOException
estimate
public void estimate(int iterationsThisRound)
throws java.io.IOException
- Throws:
java.io.IOException
sampleTopicsForOneDoc
protected void sampleTopicsForOneDoc(FeatureSequence tokenSequence,
FeatureSequence topicSequence,
boolean shouldSaveState,
boolean readjustTopicsAndStats)
getSortedTopicWords
public IDSorter[] getSortedTopicWords(int topic)
printTopWords
public void printTopWords(java.io.File file,
int numWords,
boolean useNewLines)
throws java.io.IOException
- Throws:
java.io.IOException
printTopWords
public void printTopWords(java.io.PrintStream out,
int numWords,
boolean usingNewLines)
topicXMLReport
public void topicXMLReport(java.io.PrintWriter out,
int numWords)
printDocumentTopics
public void printDocumentTopics(java.io.File f)
throws java.io.IOException
- Throws:
java.io.IOException
printDocumentTopics
public void printDocumentTopics(java.io.PrintWriter pw)
printDocumentTopics
public void printDocumentTopics(java.io.PrintWriter pw,
double threshold,
int max)
- Parameters:
pw
- A print writer
threshold
- Only print topics with proportion greater than this number
max
- Print no more than this many topics
printState
public void printState(java.io.File f)
throws java.io.IOException
- Throws:
java.io.IOException
printState
public void printState(java.io.PrintStream out)
write
public void write(java.io.File f)
read
public static LDAHyper read(java.io.File f)
topicLabelMutualInformation
public double topicLabelMutualInformation()
empiricalLikelihood
public double empiricalLikelihood(int numSamples,
InstanceList testing)
modelLogLikelihood
public double modelLogLikelihood()
main
public static void main(java.lang.String[] args)
throws java.io.IOException
- Throws:
java.io.IOException