00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _SIMPLEKLRETMETHOD_HPP
00013 #define _SIMPLEKLRETMETHOD_HPP
00014
00015 #include <cmath>
00016 #include "UnigramLM.hpp"
00017 #include "ScoreFunction.hpp"
00018 #include "SimpleKLDocModel.hpp"
00019 #include "TextQueryRep.hpp"
00020 #include "TextQueryRetMethod.hpp"
00021 #include "Counter.hpp"
00022 #include "DocUnigramCounter.hpp"
00023
00025
00026 class SimpleKLQueryModel : public ArrayQueryRep {
00027 public:
00029 SimpleKLQueryModel(const TermQuery &qry, const Index &dbIndex) :
00030 ArrayQueryRep(dbIndex.termCountUnique()+1, qry, dbIndex), qm(NULL),
00031 ind(dbIndex), colKLComputed(false) {
00032 colQLikelihood = 0;
00033 colQueryLikelihood();
00034 }
00035
00037 SimpleKLQueryModel(const Index &dbIndex) :
00038 ArrayQueryRep(dbIndex.termCountUnique()+1), qm(NULL), ind(dbIndex),
00039 colKLComputed(false) {
00040 colQLikelihood = 0;
00041 startIteration();
00042 while (hasMore()) {
00043 QueryTerm *qt = nextTerm();
00044 setCount(qt->id(), 0);
00045 delete qt;
00046 }
00047 }
00048
00049
00050 virtual ~SimpleKLQueryModel(){ if (qm) delete qm;}
00051
00052
00054
00061 virtual void interpolateWith(const UnigramLM &qModel, double origModCoeff,
00062 int howManyWord, double prSumThresh=1,
00063 double prThresh=0);
00064 virtual double scoreConstant() const {
00065 return totalCount();
00066 }
00067
00069 virtual void load(istream &is);
00070
00072 virtual void save(ostream &os);
00073
00075 virtual void clarity(ostream &os);
00077 virtual double clarity() const;
00078
00080 double colDivergence() const {
00081 if (colKLComputed) {
00082 return colKL;
00083 } else {
00084 colKLComputed = true;
00085 double d=0;
00086 startIteration();
00087 while (hasMore()) {
00088 QueryTerm *qt=nextTerm();
00089 double pr = qt->weight()/(double)totalCount();
00090 double colPr = ((double)ind.termCount(qt->id()) /
00091 (double)(ind.termCount()));
00092 d += pr*log(pr/colPr);
00093 delete qt;
00094 }
00095 colKL=d;
00096 return d;
00097 }
00098 }
00099
00101 double KLDivergence(const UnigramLM &refMod) {
00102 double d=0;
00103 startIteration();
00104 while (hasMore()) {
00105 QueryTerm *qt=nextTerm();
00106 double pr = qt->weight()/(double)totalCount();
00107 d += pr*log(pr/refMod.prob(qt->id()));
00108 delete qt;
00109 }
00110 return d;
00111 }
00112
00113 double colQueryLikelihood() const {
00114 if (colQLikelihood == 0) {
00115
00116 COUNT_T tc = ind.termCount();
00117 startIteration();
00118 while (hasMore()) {
00119 QueryTerm *qt = nextTerm();
00120 TERMID_T id = qt->id();
00121 double qtf = qt->weight();
00122 COUNT_T qtcf = ind.termCount(id);
00123 double s = qtf * log((double)qtcf/(double)tc);
00124 colQLikelihood += s;
00125 delete qt;
00126 }
00127 }
00128 return colQLikelihood;
00129 }
00130
00131
00132 protected:
00133
00134 mutable double colQLikelihood;
00135 mutable double colKL;
00136 mutable bool colKLComputed;
00137
00138 IndexedRealVector *qm;
00139 const Index &ind;
00140 };
00141
00142
00143
00145
00160 class SimpleKLScoreFunc : public ScoreFunction {
00161 public:
00162 enum SimpleKLParameter::adjustedScoreMethods adjScoreMethod;
00163 void setScoreMethod(enum SimpleKLParameter::adjustedScoreMethods adj) {
00164 adjScoreMethod = adj;
00165 }
00166 virtual double matchedTermWeight(const QueryTerm *qTerm,
00167 const TextQueryRep *qRep,
00168 const DocInfo *info,
00169 const DocumentRep *dRep) const {
00170 double w = qTerm->weight();
00171 double d = dRep->termWeight(qTerm->id(),info);
00172 double l = log(d);
00173 double score = w*l;
00174
00175
00176
00177
00178 return score;
00179
00180 }
00182 virtual double adjustedScore(double origScore,
00183 const TextQueryRep *qRep,
00184 const DocumentRep *dRep) const {
00185 const SimpleKLQueryModel *qm = dynamic_cast<const SimpleKLQueryModel *>(qRep);
00186
00187
00188
00189
00190 double qsc = qm->scoreConstant();
00191 double dsc = log(dRep->scoreConstant());
00192 double cql = qm->colQueryLikelihood();
00193
00194 double s = dsc * qsc + origScore + cql;
00195 double qsNorm = origScore/qsc;
00196 double qmD = qm->colDivergence();
00197
00198
00199
00200
00202 switch (adjScoreMethod) {
00203 case SimpleKLParameter::QUERYLIKELIHOOD:
00205
00206
00207
00208 return s;
00209
00210 case SimpleKLParameter::CROSSENTROPY:
00212
00213 assert(qm->scoreConstant()!=0);
00214
00215
00216 s = qsNorm + dsc + cql/qsc;
00217 return (s);
00218 case SimpleKLParameter::NEGATIVEKLD:
00220
00221 assert(qm->scoreConstant()!=0);
00222 s = qsNorm + dsc - qmD;
00223
00224
00225
00226 return s;
00227
00228
00229 default:
00230 cerr << "unknown adjusted score method" << endl;
00231 return origScore;
00232 }
00233 }
00234
00235 };
00236
00238 class SimpleKLRetMethod : public TextQueryRetMethod {
00239 public:
00240
00242 SimpleKLRetMethod(const Index &dbIndex, const string &supportFileName,
00243 ScoreAccumulator &accumulator);
00244 virtual ~SimpleKLRetMethod();
00245
00246 virtual TextQueryRep *computeTextQueryRep(const TermQuery &qry) {
00247 return (new SimpleKLQueryModel(qry, ind));
00248 }
00249
00250 virtual DocumentRep *computeDocRep(DOCID_T docID);
00251
00252
00253 virtual ScoreFunction *scoreFunc() {
00254 return (scFunc);
00255 }
00256
00257 virtual void updateTextQuery(TextQueryRep &origRep,
00258 const DocIDSet &relDocs);
00259
00260 void setDocSmoothParam(SimpleKLParameter::DocSmoothParam &docSmthParam);
00261 void setQueryModelParam(SimpleKLParameter::QueryModelParam &queryModParam);
00262
00263 protected:
00264
00266 double *mcNorm;
00267
00269 double *docProbMass;
00271 COUNT_T *uniqueTermCount;
00273 UnigramLM *collectLM;
00275 DocUnigramCounter *collectLMCounter;
00277 SimpleKLScoreFunc *scFunc;
00278
00280
00281
00282 void computeMixtureFBModel(SimpleKLQueryModel &origRep,
00283 const DocIDSet & relDocs);
00285 void computeDivMinFBModel(SimpleKLQueryModel &origRep,
00286 const DocIDSet &relDocs);
00288 void computeMarkovChainFBModel(SimpleKLQueryModel &origRep,
00289 const DocIDSet &relDocs) ;
00291 void computeRM1FBModel(SimpleKLQueryModel &origRep,
00292 const DocIDSet & relDocs);
00294 void computeRM2FBModel(SimpleKLQueryModel &origRep,
00295 const DocIDSet & relDocs);
00297
00298 SimpleKLParameter::DocSmoothParam docParam;
00299 SimpleKLParameter::QueryModelParam qryParam;
00300
00302 void loadSupportFile();
00303 const string supportFile;
00304 };
00305
00306
00307 inline void SimpleKLRetMethod::setDocSmoothParam(SimpleKLParameter::DocSmoothParam &docSmthParam)
00308 {
00309 docParam = docSmthParam;
00310 loadSupportFile();
00311 }
00312
00313 inline void SimpleKLRetMethod::setQueryModelParam(SimpleKLParameter::QueryModelParam &queryModParam)
00314 {
00315 qryParam = queryModParam;
00316
00317
00318 scFunc->setScoreMethod(qryParam.adjScoreMethod);
00319 loadSupportFile();
00320 }
00321
00322 #endif