1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 2 /* */ 3 /* This file is part of the program and library */ 4 /* SCIP --- Solving Constraint Integer Programs */ 5 /* */ 6 /* Copyright (c) 2002-2023 Zuse Institute Berlin (ZIB) */ 7 /* */ 8 /* Licensed under the Apache License, Version 2.0 (the "License"); */ 9 /* you may not use this file except in compliance with the License. */ 10 /* You may obtain a copy of the License at */ 11 /* */ 12 /* http://www.apache.org/licenses/LICENSE-2.0 */ 13 /* */ 14 /* Unless required by applicable law or agreed to in writing, software */ 15 /* distributed under the License is distributed on an "AS IS" BASIS, */ 16 /* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ 17 /* See the License for the specific language governing permissions and */ 18 /* limitations under the License. */ 19 /* */ 20 /* You should have received a copy of the Apache-2.0 license */ 21 /* along with SCIP; see the file LICENSE. If not visit scipopt.org. */ 22 /* */ 23 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 24 25 /**@file bandit_ucb.c 26 * @ingroup OTHER_CFILES 27 * @brief methods for UCB bandit selection 28 * @author Gregor Hendel 29 */ 30 31 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/ 32 33 #include "scip/bandit.h" 34 #include "scip/bandit_ucb.h" 35 #include "scip/pub_bandit.h" 36 #include "scip/pub_message.h" 37 #include "scip/pub_misc.h" 38 #include "scip/pub_misc_sort.h" 39 #include "scip/scip_bandit.h" 40 #include "scip/scip_mem.h" 41 #include "scip/scip_randnumgen.h" 42 43 44 #define BANDIT_NAME "ucb" 45 #define NUMEPS 1e-6 46 47 /* 48 * Data structures 49 */ 50 51 /** implementation specific data of UCB bandit algorithm */ 52 struct SCIP_BanditData 53 { 54 int nselections; /**< counter for the number of selections */ 55 int* counter; /**< array of counters how often every action has been chosen */ 56 int* startperm; /**< indices for starting permutation */ 57 SCIP_Real* meanscores; /**< array of average scores for the actions */ 58 SCIP_Real alpha; /**< parameter to increase confidence width */ 59 }; 60 61 62 /* 63 * Local methods 64 */ 65 66 /** data reset method */ 67 static 68 SCIP_RETCODE dataReset( 69 BMS_BUFMEM* bufmem, /**< buffer memory */ 70 SCIP_BANDIT* ucb, /**< ucb bandit algorithm */ 71 SCIP_BANDITDATA* banditdata, /**< UCB bandit data structure */ 72 SCIP_Real* priorities, /**< priorities for start permutation, or NULL */ 73 int nactions /**< number of actions */ 74 ) 75 { 76 int i; 77 SCIP_RANDNUMGEN* rng; 78 79 assert(bufmem != NULL); 80 assert(ucb != NULL); 81 assert(nactions > 0); 82 83 /* clear counters and scores */ 84 BMSclearMemoryArray(banditdata->counter, nactions); 85 BMSclearMemoryArray(banditdata->meanscores, nactions); 86 banditdata->nselections = 0; 87 88 rng = SCIPbanditGetRandnumgen(ucb); 89 assert(rng != NULL); 90 91 /* initialize start permutation as identity */ 92 for( i = 0; i < nactions; ++i ) 93 banditdata->startperm[i] = i; 94 95 /* prepare the start permutation in decreasing order of priority */ 96 if( priorities != NULL ) 97 { 98 SCIP_Real* prioritycopy; 99 100 SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) ); 101 102 /* randomly wiggle priorities a little bit to make them unique */ 103 for( i = 0; i < nactions; ++i ) 104 prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS); 105 106 SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions); 107 108 BMSfreeBufferMemoryArray(bufmem, &prioritycopy); 109 } 110 else 111 { 112 /* use a random start permutation */ 113 SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions); 114 } 115 116 return SCIP_OKAY; 117 } 118 119 120 /* 121 * Callback methods of bandit algorithm 122 */ 123 124 /** callback to free bandit specific data structures */ 125 SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb) 126 { /*lint --e{715}*/ 127 SCIP_BANDITDATA* banditdata; 128 int nactions; 129 assert(bandit != NULL); 130 131 banditdata = SCIPbanditGetData(bandit); 132 assert(banditdata != NULL); 133 nactions = SCIPbanditGetNActions(bandit); 134 135 BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions); 136 BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions); 137 BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions); 138 BMSfreeBlockMemory(blkmem, &banditdata); 139 140 SCIPbanditSetData(bandit, NULL); 141 142 return SCIP_OKAY; 143 } 144 145 /** selection callback for bandit selector */ 146 SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb) 147 { /*lint --e{715}*/ 148 SCIP_BANDITDATA* banditdata; 149 int nactions; 150 int* counter; 151 152 assert(bandit != NULL); 153 assert(selection != NULL); 154 155 banditdata = SCIPbanditGetData(bandit); 156 assert(banditdata != NULL); 157 nactions = SCIPbanditGetNActions(bandit); 158 159 counter = banditdata->counter; 160 /* select the next uninitialized action from the start permutation */ 161 if( banditdata->nselections < nactions ) 162 { 163 *selection = banditdata->startperm[banditdata->nselections]; 164 assert(counter[*selection] == 0); 165 } 166 else 167 { 168 /* select the action with the highest upper confidence bound */ 169 SCIP_Real* meanscores; 170 SCIP_Real widthfactor; 171 SCIP_Real maxucb; 172 int i; 173 SCIP_RANDNUMGEN* rng = SCIPbanditGetRandnumgen(bandit); 174 meanscores = banditdata->meanscores; 175 176 assert(rng != NULL); 177 assert(meanscores != NULL); 178 179 /* compute the confidence width factor that is common for all actions */ 180 /* cppcheck-suppress unpreciseMathCall */ 181 widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections); 182 widthfactor = sqrt(widthfactor); 183 maxucb = -1.0; 184 185 /* loop over the actions and determine the maximum upper confidence bound. 186 * The upper confidence bound of an action is the sum of its mean score 187 * plus a confidence term that decreases with increasing number of observations of 188 * this action. 189 */ 190 for( i = 0; i < nactions; ++i ) 191 { 192 SCIP_Real uppercb; 193 SCIP_Real rootcount; 194 assert(counter[i] > 0); 195 196 /* compute the upper confidence bound for action i */ 197 uppercb = meanscores[i]; 198 rootcount = sqrt((SCIP_Real)counter[i]); 199 uppercb += widthfactor / rootcount; 200 assert(uppercb > 0); 201 202 /* update maximum, breaking ties uniformly at random */ 203 if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) ) 204 { 205 maxucb = uppercb; 206 *selection = i; 207 } 208 } 209 } 210 211 assert(*selection >= 0); 212 assert(*selection < nactions); 213 214 return SCIP_OKAY; 215 } 216 217 /** update callback for bandit algorithm */ 218 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb) 219 { /*lint --e{715}*/ 220 SCIP_BANDITDATA* banditdata; 221 SCIP_Real delta; 222 223 assert(bandit != NULL); 224 225 banditdata = SCIPbanditGetData(bandit); 226 assert(banditdata != NULL); 227 assert(selection >= 0); 228 assert(selection < SCIPbanditGetNActions(bandit)); 229 230 /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */ 231 delta = score - banditdata->meanscores[selection]; 232 ++banditdata->counter[selection]; 233 banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection]; 234 235 banditdata->nselections++; 236 237 return SCIP_OKAY; 238 } 239 240 /** reset callback for bandit algorithm */ 241 SCIP_DECL_BANDITRESET(SCIPbanditResetUcb) 242 { /*lint --e{715}*/ 243 SCIP_BANDITDATA* banditdata; 244 int nactions; 245 246 assert(bufmem != NULL); 247 assert(bandit != NULL); 248 249 banditdata = SCIPbanditGetData(bandit); 250 assert(banditdata != NULL); 251 nactions = SCIPbanditGetNActions(bandit); 252 253 /* call the data reset for the given priorities */ 254 SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) ); 255 256 return SCIP_OKAY; 257 } 258 259 /* 260 * bandit algorithm specific interface methods 261 */ 262 263 /** returns the upper confidence bound of a selected action */ 264 SCIP_Real SCIPgetConfidenceBoundUcb( 265 SCIP_BANDIT* ucb, /**< UCB bandit algorithm */ 266 int action /**< index of the queried action */ 267 ) 268 { 269 SCIP_Real uppercb; 270 SCIP_BANDITDATA* banditdata; 271 int nactions; 272 273 assert(ucb != NULL); 274 banditdata = SCIPbanditGetData(ucb); 275 nactions = SCIPbanditGetNActions(ucb); 276 assert(action < nactions); 277 278 /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */ 279 if( banditdata->nselections < nactions ) 280 return 1.0; 281 282 /* the bandit algorithm must have picked every action once */ 283 assert(banditdata->counter[action] > 0); 284 uppercb = banditdata->meanscores[action]; 285 286 /* cppcheck-suppress unpreciseMathCall */ 287 uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]); 288 289 return uppercb; 290 } 291 292 /** return start permutation of the UCB bandit algorithm */ 293 int* SCIPgetStartPermutationUcb( 294 SCIP_BANDIT* ucb /**< UCB bandit algorithm */ 295 ) 296 { 297 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb); 298 299 assert(banditdata != NULL); 300 301 return banditdata->startperm; 302 } 303 304 /** internal method to create and reset UCB bandit algorithm */ 305 SCIP_RETCODE SCIPbanditCreateUcb( 306 BMS_BLKMEM* blkmem, /**< block memory */ 307 BMS_BUFMEM* bufmem, /**< buffer memory */ 308 SCIP_BANDITVTABLE* vtable, /**< virtual function table for UCB bandit algorithm */ 309 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */ 310 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */ 311 SCIP_Real alpha, /**< parameter to increase confidence width */ 312 int nactions, /**< the positive number of actions for this bandit algorithm */ 313 unsigned int initseed /**< initial random seed */ 314 ) 315 { 316 SCIP_BANDITDATA* banditdata; 317 318 if( alpha < 0.0 ) 319 { 320 SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha); 321 return SCIP_INVALIDDATA; 322 } 323 324 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) ); 325 assert(banditdata != NULL); 326 327 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) ); 328 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) ); 329 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) ); 330 331 banditdata->alpha = alpha; 332 333 SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) ); 334 335 return SCIP_OKAY; 336 } 337 338 /** create and reset UCB bandit algorithm */ 339 SCIP_RETCODE SCIPcreateBanditUcb( 340 SCIP* scip, /**< SCIP data structure */ 341 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */ 342 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */ 343 SCIP_Real alpha, /**< parameter to increase confidence width */ 344 int nactions, /**< the positive number of actions for this bandit algorithm */ 345 unsigned int initseed /**< initial random number seed */ 346 ) 347 { 348 SCIP_BANDITVTABLE* vtable; 349 350 vtable = SCIPfindBanditvtable(scip, BANDIT_NAME); 351 if( vtable == NULL ) 352 { 353 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME); 354 return SCIP_INVALIDDATA; 355 } 356 357 SCIP_CALL( SCIPbanditCreateUcb(SCIPblkmem(scip), SCIPbuffer(scip), vtable, ucb, 358 priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, initseed)) ); 359 360 return SCIP_OKAY; 361 } 362 363 /** include virtual function table for UCB bandit algorithms */ 364 SCIP_RETCODE SCIPincludeBanditvtableUcb( 365 SCIP* scip /**< SCIP data structure */ 366 ) 367 { 368 SCIP_BANDITVTABLE* vtable; 369 370 SCIP_CALL( SCIPincludeBanditvtable(scip, &vtable, BANDIT_NAME, 371 SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) ); 372 assert(vtable != NULL); 373 374 return SCIP_OKAY; 375 } 376