1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 2 /* */ 3 /* This file is part of the program and library */ 4 /* SCIP --- Solving Constraint Integer Programs */ 5 /* */ 6 /* Copyright (c) 2002-2023 Zuse Institute Berlin (ZIB) */ 7 /* */ 8 /* Licensed under the Apache License, Version 2.0 (the "License"); */ 9 /* you may not use this file except in compliance with the License. */ 10 /* You may obtain a copy of the License at */ 11 /* */ 12 /* http://www.apache.org/licenses/LICENSE-2.0 */ 13 /* */ 14 /* Unless required by applicable law or agreed to in writing, software */ 15 /* distributed under the License is distributed on an "AS IS" BASIS, */ 16 /* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ 17 /* See the License for the specific language governing permissions and */ 18 /* limitations under the License. */ 19 /* */ 20 /* You should have received a copy of the Apache-2.0 license */ 21 /* along with SCIP; see the file LICENSE. If not visit scipopt.org. */ 22 /* */ 23 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 24 25 /**@file bandit_epsgreedy.c 26 * @ingroup OTHER_CFILES 27 * @brief implementation of (a variant of) epsilon greedy bandit algorithm 28 * @author Gregor Hendel 29 */ 30 31 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/ 32 33 #include "scip/bandit.h" 34 #include "scip/bandit_epsgreedy.h" 35 #include "scip/pub_bandit.h" 36 #include "scip/pub_message.h" 37 #include "scip/pub_misc.h" 38 #include "scip/scip_bandit.h" 39 #include "scip/scip_mem.h" 40 #include "scip/scip_randnumgen.h" 41 42 #define BANDIT_NAME "eps-greedy" 43 #define EPSGREEDY_SMALL 1e-6 44 45 /* 46 * Data structures 47 */ 48 49 /** private data structure of epsilon greedy bandit algorithm */ 50 struct SCIP_BanditData 51 { 52 SCIP_Real* weights; /**< weights for every action */ 53 SCIP_Real* priorities; /**< saved priorities for tie breaking */ 54 int* sels; /**< individual number of selections per action */ 55 SCIP_Real eps; /**< epsilon parameter (between 0 and 1) to control epsilon greedy */ 56 SCIP_Bool usemodification; /**< TRUE if modified eps greedy should be used */ 57 SCIP_Real decayfactor; /**< the factor to reduce the weight of older observations if exponential decay is enabled */ 58 int avglim; /**< nonnegative limit on observation number before the exponential decay starts, 59 * only relevant if exponential decay is enabled 60 */ 61 int nselections; /**< counter for the number of selection calls */ 62 SCIP_Bool preferrecent; /**< should the weights be updated in an exponentially decaying way? */ 63 }; 64 65 /* 66 * Callback methods of bandit algorithm virtual function table 67 */ 68 69 /** callback to free bandit specific data structures */ 70 SCIP_DECL_BANDITFREE(SCIPbanditFreeEpsgreedy) 71 { /*lint --e{715}*/ 72 SCIP_BANDITDATA* banditdata; 73 int nactions; 74 75 assert(bandit != NULL); 76 77 banditdata = SCIPbanditGetData(bandit); 78 assert(banditdata != NULL); 79 assert(banditdata->weights != NULL); 80 nactions = SCIPbanditGetNActions(bandit); 81 82 BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions); 83 BMSfreeBlockMemoryArray(blkmem, &banditdata->priorities, nactions); 84 BMSfreeBlockMemoryArray(blkmem, &banditdata->sels, nactions); 85 BMSfreeBlockMemory(blkmem, &banditdata); 86 87 SCIPbanditSetData(bandit, NULL); 88 89 return SCIP_OKAY; 90 } 91 92 /** selection callback for bandit algorithm */ 93 SCIP_DECL_BANDITSELECT(SCIPbanditSelectEpsgreedy) 94 { /*lint --e{715}*/ 95 SCIP_BANDITDATA* banditdata; 96 SCIP_Real randnr; 97 SCIP_Real curreps; 98 SCIP_RANDNUMGEN* rng; 99 int nactions; 100 assert(bandit != NULL); 101 assert(selection != NULL); 102 103 banditdata = SCIPbanditGetData(bandit); 104 assert(banditdata != NULL); 105 rng = SCIPbanditGetRandnumgen(bandit); 106 assert(rng != NULL); 107 108 nactions = SCIPbanditGetNActions(bandit); 109 110 /* roll the dice to check if the best element should be picked, or an element at random */ 111 randnr = SCIPrandomGetReal(rng, 0.0, 1.0); 112 113 /* make epsilon decrease with an increasing number of selections */ 114 banditdata->nselections++; 115 curreps = banditdata->eps * sqrt((SCIP_Real)nactions/(SCIP_Real)banditdata->nselections); 116 117 /* select the best action seen so far */ 118 if( randnr >= curreps ) 119 { 120 SCIP_Real* weights = banditdata->weights; 121 SCIP_Real* priorities = banditdata->priorities; 122 int j; 123 SCIP_Real maxweight; 124 125 assert(weights != NULL); 126 assert(priorities != NULL); 127 128 /* pick the element with the largest reward */ 129 maxweight = weights[0]; 130 *selection = 0; 131 132 /* determine reward for every element */ 133 for( j = 1; j < nactions; ++j ) 134 { 135 SCIP_Real weight = weights[j]; 136 137 /* select the action that maximizes the reward, breaking ties by action priorities */ 138 if( maxweight < weight 139 || (weight >= maxweight - EPSGREEDY_SMALL && priorities[j] > priorities[*selection] ) ) 140 { 141 *selection = j; 142 maxweight = weight; 143 } 144 } 145 } 146 else if( ! banditdata->usemodification ) /* use normal eps greedy */ 147 { 148 /* play one of the actions at random */ 149 *selection = SCIPrandomGetInt(rng, 0, nactions - 1); 150 } 151 else /* pick an action w.r.t. the distributions defined by its weights */ 152 { 153 int j; 154 SCIP_Real sum; 155 SCIP_Real weightsum; 156 SCIP_Real* weights = banditdata->weights; 157 158 weightsum = 0.0; 159 for( j = 0; j < nactions; ++j ) 160 weightsum += banditdata->weights[j]; 161 162 /* pick a random number between 0.0 and sum of weights */ 163 randnr = SCIPrandomGetReal(rng, 0.0, weightsum); 164 165 /* pick action w.r.t. the weights distribution */ 166 sum = 0.0; 167 *selection = -1; 168 for( j = 0; j < nactions - 1; ++j ) 169 { 170 sum += weights[j]; 171 172 if( sum >= randnr ) 173 { 174 *selection = j; 175 break; 176 } 177 } 178 179 if( *selection < 0 ) 180 *selection = nactions - 1; 181 } 182 183 return SCIP_OKAY; 184 } 185 186 /** update callback for bandit algorithm */ 187 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateEpsgreedy) 188 { /*lint --e{715}*/ 189 SCIP_BANDITDATA* banditdata; 190 191 assert(bandit != NULL); 192 193 banditdata = SCIPbanditGetData(bandit); 194 assert(banditdata != NULL); 195 196 /* increase the selection count */ 197 ++banditdata->sels[selection]; 198 199 /* the very first observation is directly stored as weight for both average or exponential decay */ 200 if( banditdata->sels[selection] == 1 ) 201 banditdata->weights[selection] = score; 202 else 203 { 204 /* use exponentially decreasing weights for older observations */ 205 if( banditdata->preferrecent && banditdata->sels[selection] > banditdata->avglim ) 206 { 207 /* decrease old weights by decay factor */ 208 banditdata->weights[selection] *= banditdata->decayfactor; 209 banditdata->weights[selection] += (1.0 - banditdata->decayfactor) * score; 210 } 211 else 212 { 213 /* update average score */ 214 SCIP_Real diff = score - banditdata->weights[selection]; 215 banditdata->weights[selection] += diff / (SCIP_Real)(banditdata->sels[selection]); 216 } 217 } 218 219 return SCIP_OKAY; 220 } 221 222 /** reset callback for bandit algorithm */ 223 SCIP_DECL_BANDITRESET(SCIPbanditResetEpsgreedy) 224 { /*lint --e{715}*/ 225 SCIP_BANDITDATA* banditdata; 226 SCIP_Real* weights; 227 int w; 228 int nactions; 229 SCIP_RANDNUMGEN* rng; 230 231 assert(bandit != NULL); 232 233 banditdata = SCIPbanditGetData(bandit); 234 assert(banditdata != NULL); 235 236 weights = banditdata->weights; 237 nactions = SCIPbanditGetNActions(bandit); 238 assert(weights != NULL); 239 assert(banditdata->priorities != NULL); 240 assert(nactions > 0); 241 242 rng = SCIPbanditGetRandnumgen(bandit); 243 assert(rng != NULL); 244 245 /* alter priorities slightly to make them unique */ 246 if( priorities != NULL ) 247 { 248 for( w = 1; w < nactions; ++w ) 249 { 250 assert(priorities[w] >= 0); 251 banditdata->priorities[w] = priorities[w] + SCIPrandomGetReal(rng, -EPSGREEDY_SMALL, EPSGREEDY_SMALL); 252 } 253 } 254 else 255 { 256 /* use random priorities */ 257 for( w = 0; w < nactions; ++w ) 258 banditdata->priorities[w] = SCIPrandomGetReal(rng, 0.0, 1.0); 259 } 260 261 /* reset weights and selection counters to 0 */ 262 BMSclearMemoryArray(weights, nactions); 263 BMSclearMemoryArray(banditdata->sels, nactions); 264 265 banditdata->nselections = 0; 266 267 return SCIP_OKAY; 268 } 269 270 /* 271 * interface methods of the Epsilon Greedy bandit algorithm 272 */ 273 274 /** internal method to create and reset epsilon greedy bandit algorithm */ 275 SCIP_RETCODE SCIPbanditCreateEpsgreedy( 276 BMS_BLKMEM* blkmem, /**< block memory */ 277 BMS_BUFMEM* bufmem, /**< buffer memory */ 278 SCIP_BANDITVTABLE* vtable, /**< virtual function table with epsilon greedy callbacks */ 279 SCIP_BANDIT** epsgreedy, /**< pointer to store the epsilon greedy bandit algorithm */ 280 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */ 281 SCIP_Real eps, /**< parameter to increase probability for exploration between all actions */ 282 SCIP_Bool usemodification, /**< TRUE if modified eps greedy should be used */ 283 SCIP_Bool preferrecent, /**< should the weights be updated in an exponentially decaying way? */ 284 SCIP_Real decayfactor, /**< the factor to reduce the weight of older observations if exponential decay is enabled */ 285 int avglim, /**< nonnegative limit on observation number before the exponential decay starts, 286 * only relevant if exponential decay is enabled */ 287 int nactions, /**< the positive number of possible actions */ 288 unsigned int initseed /**< initial random seed */ 289 ) 290 { 291 SCIP_BANDITDATA* banditdata; 292 293 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) ); 294 assert(banditdata != NULL); 295 assert(eps >= 0.0); 296 297 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) ); 298 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->priorities, nactions) ); 299 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->sels, nactions) ); 300 banditdata->eps = eps; 301 banditdata->nselections = 0; 302 banditdata->usemodification = usemodification; 303 banditdata->preferrecent = preferrecent; 304 banditdata->decayfactor = decayfactor; 305 banditdata->avglim = avglim; 306 307 SCIP_CALL( SCIPbanditCreate(epsgreedy, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) ); 308 309 return SCIP_OKAY; 310 } 311 312 /** create and resets an epsilon greedy bandit algorithm */ 313 SCIP_RETCODE SCIPcreateBanditEpsgreedy( 314 SCIP* scip, /**< SCIP data structure */ 315 SCIP_BANDIT** epsgreedy, /**< pointer to store the epsilon greedy bandit algorithm */ 316 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */ 317 SCIP_Real eps, /**< parameter to increase probability for exploration between all actions */ 318 SCIP_Bool usemodification, /**< TRUE if modified eps greedy should be used */ 319 SCIP_Bool preferrecent, /**< should the weights be updated in an exponentially decaying way? */ 320 SCIP_Real decayfactor, /**< the factor to reduce the weight of older observations if exponential decay is enabled */ 321 int avglim, /**< nonnegative limit on observation number before the exponential decay starts, 322 * only relevant if exponential decay is enabled */ 323 int nactions, /**< the positive number of possible actions */ 324 unsigned int initseed /**< initial seed for random number generation */ 325 ) 326 { 327 SCIP_BANDITVTABLE* vtable; 328 assert(scip != NULL); 329 assert(epsgreedy != NULL); 330 331 vtable = SCIPfindBanditvtable(scip, BANDIT_NAME); 332 if( vtable == NULL ) 333 { 334 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME); 335 return SCIP_INVALIDDATA; 336 } 337 338 SCIP_CALL( SCIPbanditCreateEpsgreedy(SCIPblkmem(scip), SCIPbuffer(scip), vtable, epsgreedy, 339 priorities, eps, usemodification, preferrecent, decayfactor, avglim, nactions, SCIPinitializeRandomSeed(scip, initseed)) ); 340 341 return SCIP_OKAY; 342 } 343 344 /** get weights array of epsilon greedy bandit algorithm */ 345 SCIP_Real* SCIPgetWeightsEpsgreedy( 346 SCIP_BANDIT* epsgreedy /**< epsilon greedy bandit algorithm */ 347 ) 348 { 349 SCIP_BANDITDATA* banditdata; 350 assert(epsgreedy != NULL); 351 banditdata = SCIPbanditGetData(epsgreedy); 352 assert(banditdata != NULL); 353 354 return banditdata->weights; 355 } 356 357 /** set epsilon parameter of epsilon greedy bandit algorithm */ 358 void SCIPsetEpsilonEpsgreedy( 359 SCIP_BANDIT* epsgreedy, /**< epsilon greedy bandit algorithm */ 360 SCIP_Real eps /**< parameter to increase probability for exploration between all actions */ 361 ) 362 { 363 SCIP_BANDITDATA* banditdata; 364 assert(epsgreedy != NULL); 365 assert(eps >= 0); 366 367 banditdata = SCIPbanditGetData(epsgreedy); 368 369 banditdata->eps = eps; 370 } 371 372 373 /** creates the epsilon greedy bandit algorithm includes it in SCIP */ 374 SCIP_RETCODE SCIPincludeBanditvtableEpsgreedy( 375 SCIP* scip /**< SCIP data structure */ 376 ) 377 { 378 SCIP_BANDITVTABLE* banditvtable; 379 380 SCIP_CALL( SCIPincludeBanditvtable(scip, &banditvtable, BANDIT_NAME, 381 SCIPbanditFreeEpsgreedy, SCIPbanditSelectEpsgreedy, SCIPbanditUpdateEpsgreedy, SCIPbanditResetEpsgreedy) ); 382 383 return SCIP_OKAY; 384 } 385