1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 2 /* */ 3 /* This file is part of the program and library */ 4 /* SCIP --- Solving Constraint Integer Programs */ 5 /* */ 6 /* Copyright (c) 2002-2023 Zuse Institute Berlin (ZIB) */ 7 /* */ 8 /* Licensed under the Apache License, Version 2.0 (the "License"); */ 9 /* you may not use this file except in compliance with the License. */ 10 /* You may obtain a copy of the License at */ 11 /* */ 12 /* http://www.apache.org/licenses/LICENSE-2.0 */ 13 /* */ 14 /* Unless required by applicable law or agreed to in writing, software */ 15 /* distributed under the License is distributed on an "AS IS" BASIS, */ 16 /* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ 17 /* See the License for the specific language governing permissions and */ 18 /* limitations under the License. */ 19 /* */ 20 /* You should have received a copy of the Apache-2.0 license */ 21 /* along with SCIP; see the file LICENSE. If not visit scipopt.org. */ 22 /* */ 23 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 24 25 /**@file bandit_exp3.c 26 * @ingroup OTHER_CFILES 27 * @brief methods for Exp.3 bandit selection 28 * @author Gregor Hendel 29 */ 30 31 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/ 32 33 #include "scip/bandit.h" 34 #include "scip/bandit_exp3.h" 35 #include "scip/pub_bandit.h" 36 #include "scip/pub_message.h" 37 #include "scip/pub_misc.h" 38 #include "scip/scip_bandit.h" 39 #include "scip/scip_mem.h" 40 #include "scip/scip_randnumgen.h" 41 42 #define BANDIT_NAME "exp3" 43 #define NUMTOL 1e-6 44 45 /* 46 * Data structures 47 */ 48 49 /** implementation specific data of Exp.3 bandit algorithm */ 50 struct SCIP_BanditData 51 { 52 SCIP_Real* weights; /**< exponential weight for each arm */ 53 SCIP_Real weightsum; /**< the sum of all weights */ 54 SCIP_Real gamma; /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */ 55 SCIP_Real beta; /**< gain offset between 0 and 1 at every observation */ 56 }; 57 58 /* 59 * Local methods 60 */ 61 62 /* 63 * Callback methods of bandit algorithm 64 */ 65 66 /** callback to free bandit specific data structures */ 67 SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3) 68 { /*lint --e{715}*/ 69 SCIP_BANDITDATA* banditdata; 70 int nactions; 71 assert(bandit != NULL); 72 73 banditdata = SCIPbanditGetData(bandit); 74 assert(banditdata != NULL); 75 nactions = SCIPbanditGetNActions(bandit); 76 77 BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions); 78 79 BMSfreeBlockMemory(blkmem, &banditdata); 80 81 SCIPbanditSetData(bandit, NULL); 82 83 return SCIP_OKAY; 84 } 85 86 /** selection callback for bandit selector */ 87 SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3) 88 { /*lint --e{715}*/ 89 SCIP_BANDITDATA* banditdata; 90 SCIP_RANDNUMGEN* rng; 91 SCIP_Real randnr; 92 SCIP_Real psum; 93 SCIP_Real gammaoverk; 94 SCIP_Real oneminusgamma; 95 SCIP_Real* weights; 96 SCIP_Real weightsum; 97 int i; 98 int nactions; 99 100 assert(bandit != NULL); 101 assert(selection != NULL); 102 103 banditdata = SCIPbanditGetData(bandit); 104 assert(banditdata != NULL); 105 rng = SCIPbanditGetRandnumgen(bandit); 106 assert(rng != NULL); 107 nactions = SCIPbanditGetNActions(bandit); 108 109 /* draw a random number between 0 and 1 */ 110 randnr = SCIPrandomGetReal(rng, 0.0, 1.0); 111 112 /* initialize some local variables to speed up probability computations */ 113 oneminusgamma = 1 - banditdata->gamma; 114 gammaoverk = banditdata->gamma / (SCIP_Real)nactions; 115 weightsum = banditdata->weightsum; 116 weights = banditdata->weights; 117 psum = 0.0; 118 119 /* loop over probability distribution until rand is reached 120 * the loop terminates without looking at the last action, 121 * which is then selected automatically if the target probability 122 * is not reached earlier 123 */ 124 for( i = 0; i < nactions - 1; ++i ) 125 { 126 SCIP_Real prob; 127 128 /* compute the probability for arm i as convex kombination of a uniform distribution and a weighted distribution */ 129 prob = oneminusgamma * weights[i] / weightsum + gammaoverk; 130 psum += prob; 131 132 /* break and select element if target probability is reached */ 133 if( randnr <= psum ) 134 break; 135 } 136 137 /* select element i, which is the last action in case that the break statement hasn't been reached */ 138 *selection = i; 139 140 return SCIP_OKAY; 141 } 142 143 /** update callback for bandit algorithm */ 144 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3) 145 { /*lint --e{715}*/ 146 SCIP_BANDITDATA* banditdata; 147 SCIP_Real eta; 148 SCIP_Real gainestim; 149 SCIP_Real beta; 150 SCIP_Real weightsum; 151 SCIP_Real newweightsum; 152 SCIP_Real* weights; 153 SCIP_Real oneminusgamma; 154 SCIP_Real gammaoverk; 155 int nactions; 156 157 assert(bandit != NULL); 158 159 banditdata = SCIPbanditGetData(bandit); 160 assert(banditdata != NULL); 161 nactions = SCIPbanditGetNActions(bandit); 162 163 assert(selection >= 0); 164 assert(selection < nactions); 165 166 /* the learning rate eta */ 167 eta = 1.0 / (SCIP_Real)nactions; 168 169 beta = banditdata->beta; 170 oneminusgamma = 1.0 - banditdata->gamma; 171 gammaoverk = banditdata->gamma * eta; 172 weights = banditdata->weights; 173 weightsum = banditdata->weightsum; 174 newweightsum = weightsum; 175 176 /* if beta is zero, only the observation for the current arm needs an update */ 177 if( EPSZ(beta, NUMTOL) ) 178 { 179 SCIP_Real probai; 180 probai = oneminusgamma * weights[selection] / weightsum + gammaoverk; 181 182 assert(probai > 0.0); 183 184 gainestim = score / probai; 185 newweightsum -= weights[selection]; 186 weights[selection] *= exp(eta * gainestim); 187 newweightsum += weights[selection]; 188 } 189 else 190 { 191 int j; 192 newweightsum = 0.0; 193 194 /* loop over all items and update their weights based on the influence of the beta parameter */ 195 for( j = 0; j < nactions; ++j ) 196 { 197 SCIP_Real probaj; 198 probaj = oneminusgamma * weights[j] / weightsum + gammaoverk; 199 200 assert(probaj > 0.0); 201 202 /* consider the score only for the chosen arm i, use constant beta offset otherwise */ 203 if( j == selection ) 204 gainestim = (score + beta) / probaj; 205 else 206 gainestim = beta / probaj; 207 208 weights[j] *= exp(eta * gainestim); 209 newweightsum += weights[j]; 210 } 211 } 212 213 banditdata->weightsum = newweightsum; 214 215 return SCIP_OKAY; 216 } 217 218 /** reset callback for bandit algorithm */ 219 SCIP_DECL_BANDITRESET(SCIPbanditResetExp3) 220 { /*lint --e{715}*/ 221 SCIP_BANDITDATA* banditdata; 222 SCIP_Real* weights; 223 int nactions; 224 int i; 225 226 assert(bandit != NULL); 227 228 banditdata = SCIPbanditGetData(bandit); 229 assert(banditdata != NULL); 230 nactions = SCIPbanditGetNActions(bandit); 231 weights = banditdata->weights; 232 233 assert(nactions > 0); 234 235 banditdata->weightsum = (1.0 + NUMTOL) * (SCIP_Real)nactions; 236 237 /* in case of priorities, weights are normalized to sum up to nactions */ 238 if( priorities != NULL ) 239 { 240 SCIP_Real normalization; 241 SCIP_Real priosum; 242 priosum = 0.0; 243 244 /* compute sum of priorities */ 245 for( i = 0; i < nactions; ++i ) 246 { 247 assert(priorities[i] >= 0); 248 priosum += priorities[i]; 249 } 250 251 /* if there are positive priorities, normalize the weights */ 252 if( priosum > 0.0 ) 253 { 254 normalization = nactions / priosum; 255 for( i = 0; i < nactions; ++i ) 256 weights[i] = (priorities[i] * normalization) + NUMTOL; 257 } 258 else 259 { 260 /* use uniform distribution in case of all priorities being 0.0 */ 261 for( i = 0; i < nactions; ++i ) 262 weights[i] = 1.0 + NUMTOL; 263 } 264 } 265 else 266 { 267 /* use uniform distribution in case of unspecified priorities */ 268 for( i = 0; i < nactions; ++i ) 269 weights[i] = 1.0 + NUMTOL; 270 } 271 272 return SCIP_OKAY; 273 } 274 275 276 /* 277 * bandit algorithm specific interface methods 278 */ 279 280 /** direct bandit creation method for the core where no SCIP pointer is available */ 281 SCIP_RETCODE SCIPbanditCreateExp3( 282 BMS_BLKMEM* blkmem, /**< block memory data structure */ 283 BMS_BUFMEM* bufmem, /**< buffer memory */ 284 SCIP_BANDITVTABLE* vtable, /**< virtual function table for callback functions of Exp.3 */ 285 SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */ 286 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */ 287 SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */ 288 SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */ 289 int nactions, /**< the positive number of actions for this bandit algorithm */ 290 unsigned int initseed /**< initial random seed */ 291 ) 292 { 293 SCIP_BANDITDATA* banditdata; 294 295 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) ); 296 assert(banditdata != NULL); 297 298 banditdata->gamma = gammaparam; 299 banditdata->beta = beta; 300 assert(gammaparam >= 0 && gammaparam <= 1); 301 assert(beta >= 0 && beta <= 1); 302 303 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) ); 304 305 SCIP_CALL( SCIPbanditCreate(exp3, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) ); 306 307 return SCIP_OKAY; 308 } 309 310 /** creates and resets an Exp.3 bandit algorithm using \p scip pointer */ 311 SCIP_RETCODE SCIPcreateBanditExp3( 312 SCIP* scip, /**< SCIP data structure */ 313 SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */ 314 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */ 315 SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */ 316 SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */ 317 int nactions, /**< the positive number of actions for this bandit algorithm */ 318 unsigned int initseed /**< initial seed for random number generation */ 319 ) 320 { 321 SCIP_BANDITVTABLE* vtable; 322 323 vtable = SCIPfindBanditvtable(scip, BANDIT_NAME); 324 if( vtable == NULL ) 325 { 326 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME); 327 return SCIP_INVALIDDATA; 328 } 329 330 SCIP_CALL( SCIPbanditCreateExp3(SCIPblkmem(scip), SCIPbuffer(scip), vtable, exp3, 331 priorities, gammaparam, beta, nactions, SCIPinitializeRandomSeed(scip, initseed)) ); 332 333 return SCIP_OKAY; 334 } 335 336 /** set gamma parameter of Exp.3 bandit algorithm to increase weight of uniform distribution */ 337 void SCIPsetGammaExp3( 338 SCIP_BANDIT* exp3, /**< bandit algorithm */ 339 SCIP_Real gammaparam /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */ 340 ) 341 { 342 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3); 343 344 assert(gammaparam >= 0 && gammaparam <= 1); 345 346 banditdata->gamma = gammaparam; 347 } 348 349 /** set beta parameter of Exp.3 bandit algorithm to increase gain offset for actions that were not played */ 350 void SCIPsetBetaExp3( 351 SCIP_BANDIT* exp3, /**< bandit algorithm */ 352 SCIP_Real beta /**< gain offset between 0 and 1 at every observation */ 353 ) 354 { 355 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3); 356 357 assert(beta >= 0 && beta <= 1); 358 359 banditdata->beta = beta; 360 } 361 362 /** returns probability to play an action */ 363 SCIP_Real SCIPgetProbabilityExp3( 364 SCIP_BANDIT* exp3, /**< bandit algorithm */ 365 int action /**< index of the requested action */ 366 ) 367 { 368 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3); 369 370 assert(banditdata->weightsum > 0.0); 371 assert(SCIPbanditGetNActions(exp3) > 0); 372 373 return (1.0 - banditdata->gamma) * banditdata->weights[action] / banditdata->weightsum + banditdata->gamma / (SCIP_Real)SCIPbanditGetNActions(exp3); 374 } 375 376 /** include virtual function table for Exp.3 bandit algorithms */ 377 SCIP_RETCODE SCIPincludeBanditvtableExp3( 378 SCIP* scip /**< SCIP data structure */ 379 ) 380 { 381 SCIP_BANDITVTABLE* vtable; 382 383 SCIP_CALL( SCIPincludeBanditvtable(scip, &vtable, BANDIT_NAME, 384 SCIPbanditFreeExp3, SCIPbanditSelectExp3, SCIPbanditUpdateExp3, SCIPbanditResetExp3) ); 385 assert(vtable != NULL); 386 387 return SCIP_OKAY; 388 } 389