1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 2 /* */ 3 /* This file is part of the program and library */ 4 /* SCIP --- Solving Constraint Integer Programs */ 5 /* */ 6 /* Copyright 2002-2022 Zuse Institute Berlin */ 7 /* */ 8 /* Licensed under the Apache License, Version 2.0 (the "License"); */ 9 /* you may not use this file except in compliance with the License. */ 10 /* You may obtain a copy of the License at */ 11 /* */ 12 /* http://www.apache.org/licenses/LICENSE-2.0 */ 13 /* */ 14 /* Unless required by applicable law or agreed to in writing, software */ 15 /* distributed under the License is distributed on an "AS IS" BASIS, */ 16 /* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ 17 /* See the License for the specific language governing permissions and */ 18 /* limitations under the License. */ 19 /* */ 20 /* You should have received a copy of the Apache-2.0 license */ 21 /* along with SCIP; see the file LICENSE. If not visit scipopt.org. */ 22 /* */ 23 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 24 25 /**@file bandit_exp3ix.c 26 * @ingroup OTHER_CFILES 27 * @brief methods for Exp.3-IX bandit selection 28 * @author Antonia Chmiela 29 */ 30 31 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/ 32 33 #include "scip/bandit.h" 34 #include "scip/bandit_exp3ix.h" 35 #include "scip/pub_bandit.h" 36 #include "scip/pub_message.h" 37 #include "scip/pub_misc.h" 38 #include "scip/scip_bandit.h" 39 #include "scip/scip_mem.h" 40 #include "scip/scip_randnumgen.h" 41 42 #define BANDIT_NAME "exp3ix" 43 44 /* 45 * Data structures 46 */ 47 48 /** implementation specific data of Exp.3 bandit algorithm */ 49 struct SCIP_BanditData 50 { 51 SCIP_Real* weights; /**< exponential weight for each arm */ 52 SCIP_Real weightsum; /**< the sum of all weights */ 53 int iter; /**< current iteration counter to compute parameters gamma_t and eta_t */ 54 }; 55 56 /* 57 * Local methods 58 */ 59 60 /* 61 * Callback methods of bandit algorithm 62 */ 63 64 /** callback to free bandit specific data structures */ 65 SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3IX) 66 { /*lint --e{715}*/ 67 SCIP_BANDITDATA* banditdata; 68 int nactions; 69 assert(bandit != NULL); 70 71 banditdata = SCIPbanditGetData(bandit); 72 assert(banditdata != NULL); 73 nactions = SCIPbanditGetNActions(bandit); 74 75 BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions); 76 77 BMSfreeBlockMemory(blkmem, &banditdata); 78 79 SCIPbanditSetData(bandit, NULL); 80 81 return SCIP_OKAY; 82 } 83 84 /** selection callback for bandit selector */ 85 SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3IX) 86 { /*lint --e{715}*/ 87 SCIP_BANDITDATA* banditdata; 88 SCIP_RANDNUMGEN* rng; 89 SCIP_Real* weights; 90 SCIP_Real weightsum; 91 int i; 92 int nactions; 93 SCIP_Real psum; 94 SCIP_Real randnr; 95 96 assert(bandit != NULL); 97 assert(selection != NULL); 98 99 banditdata = SCIPbanditGetData(bandit); 100 assert(banditdata != NULL); 101 rng = SCIPbanditGetRandnumgen(bandit); 102 assert(rng != NULL); 103 nactions = SCIPbanditGetNActions(bandit); 104 105 /* initialize some local variables to speed up probability computations */ 106 weightsum = banditdata->weightsum; 107 weights = banditdata->weights; 108 109 /* draw a random number between 0 and 1 */ 110 randnr = SCIPrandomGetReal(rng, 0.0, 1.0); 111 112 /* loop over probability distribution until rand is reached 113 * the loop terminates without looking at the last action, 114 * which is then selected automatically if the target probability 115 * is not reached earlier 116 */ 117 psum = 0.0; 118 for( i = 0; i < nactions - 1; ++i ) 119 { 120 SCIP_Real prob; 121 122 /* compute the probability for arm i */ 123 prob = weights[i] / weightsum; 124 psum += prob; 125 126 /* break and select element if target probability is reached */ 127 if( randnr <= psum ) 128 break; 129 } 130 131 /* select element i, which is the last action in case that the break statement hasn't been reached */ 132 *selection = i; 133 134 return SCIP_OKAY; 135 } 136 137 /** compute gamma_t */ 138 static 139 SCIP_Real SCIPcomputeGamma( 140 int nactions, /**< the positive number of actions for this bandit algorithm */ 141 int t /**< current iteration */ 142 ) 143 { 144 return SQRT(log((SCIP_Real)nactions) / (4.0 * (SCIP_Real)t * (SCIP_Real)nactions) ); 145 } 146 147 /** update callback for bandit algorithm */ 148 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3IX) 149 { /*lint --e{715}*/ 150 SCIP_BANDITDATA* banditdata; 151 SCIP_Real etaparam; 152 SCIP_Real lossestim; 153 SCIP_Real prob; 154 SCIP_Real weightsum; 155 SCIP_Real newweightsum; 156 SCIP_Real* weights; 157 SCIP_Real gammaparam; 158 int nactions; 159 160 assert(bandit != NULL); 161 162 banditdata = SCIPbanditGetData(bandit); 163 assert(banditdata != NULL); 164 nactions = SCIPbanditGetNActions(bandit); 165 166 assert(selection >= 0); 167 assert(selection < nactions); 168 169 weights = banditdata->weights; 170 weightsum = banditdata->weightsum; 171 newweightsum = weightsum; 172 gammaparam = SCIPcomputeGamma(nactions, banditdata->iter); 173 etaparam = 2.0 * gammaparam; 174 175 /* probability of selection */ 176 prob = weights[selection] / weightsum; 177 178 /* estimated loss */ 179 lossestim = (1.0 - score) / (prob + gammaparam); 180 assert(lossestim >= 0); 181 182 /* update the observation for the current arm */ 183 newweightsum -= weights[selection]; 184 weights[selection] *= exp(-etaparam * lossestim); 185 newweightsum += weights[selection]; 186 187 banditdata->weightsum = newweightsum; 188 189 /* increase iteration counter */ 190 banditdata->iter += 1; 191 192 return SCIP_OKAY; 193 } 194 195 /** reset callback for bandit algorithm */ 196 SCIP_DECL_BANDITRESET(SCIPbanditResetExp3IX) 197 { /*lint --e{715}*/ 198 SCIP_BANDITDATA* banditdata; 199 SCIP_Real* weights; 200 int nactions; 201 int i; 202 203 assert(bandit != NULL); 204 205 banditdata = SCIPbanditGetData(bandit); 206 assert(banditdata != NULL); 207 nactions = SCIPbanditGetNActions(bandit); 208 weights = banditdata->weights; 209 210 assert(nactions > 0); 211 212 /* initialize all weights with 1.0 */ 213 for( i = 0; i < nactions; ++i ) 214 weights[i] = 1.0; 215 216 banditdata->weightsum = (SCIP_Real)nactions; 217 218 /* set iteration counter to 1 */ 219 banditdata->iter = 1; 220 221 return SCIP_OKAY; 222 } 223 224 225 /* 226 * bandit algorithm specific interface methods 227 */ 228 229 /** direct bandit creation method for the core where no SCIP pointer is available */ 230 SCIP_RETCODE SCIPbanditCreateExp3IX( 231 BMS_BLKMEM* blkmem, /**< block memory data structure */ 232 BMS_BUFMEM* bufmem, /**< buffer memory */ 233 SCIP_BANDITVTABLE* vtable, /**< virtual function table for callback functions of Exp.3-IX */ 234 SCIP_BANDIT** exp3ix, /**< pointer to store bandit algorithm */ 235 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */ 236 int nactions, /**< the positive number of actions for this bandit algorithm */ 237 unsigned int initseed /**< initial random seed */ 238 ) 239 { 240 SCIP_BANDITDATA* banditdata; 241 242 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) ); 243 assert(banditdata != NULL); 244 245 banditdata->iter = 1; 246 247 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) ); 248 249 SCIP_CALL( SCIPbanditCreate(exp3ix, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) ); 250 251 return SCIP_OKAY; 252 } 253 254 /** creates and resets an Exp.3-IX bandit algorithm using \p scip pointer */ 255 SCIP_RETCODE SCIPcreateBanditExp3IX( 256 SCIP* scip, /**< SCIP data structure */ 257 SCIP_BANDIT** exp3ix, /**< pointer to store bandit algorithm */ 258 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */ 259 int nactions, /**< the positive number of actions for this bandit algorithm */ 260 unsigned int initseed /**< initial seed for random number generation */ 261 ) 262 { 263 SCIP_BANDITVTABLE* vtable; 264 265 vtable = SCIPfindBanditvtable(scip, BANDIT_NAME); 266 if( vtable == NULL ) 267 { 268 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME); 269 return SCIP_INVALIDDATA; 270 } 271 272 SCIP_CALL( SCIPbanditCreateExp3IX(SCIPblkmem(scip), SCIPbuffer(scip), vtable, exp3ix, 273 priorities, nactions, SCIPinitializeRandomSeed(scip, initseed)) ); 274 275 return SCIP_OKAY; 276 } 277 278 /** returns probability to play an action */ 279 SCIP_Real SCIPgetProbabilityExp3IX( 280 SCIP_BANDIT* exp3ix, /**< bandit algorithm */ 281 int action /**< index of the requested action */ 282 ) 283 { 284 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3ix); 285 286 assert(banditdata->weightsum > 0.0); 287 assert(SCIPbanditGetNActions(exp3ix) > 0); 288 289 return banditdata->weights[action] / banditdata->weightsum; 290 } 291 292 /** include virtual function table for Exp.3-IX bandit algorithms */ 293 SCIP_RETCODE SCIPincludeBanditvtableExp3IX( 294 SCIP* scip /**< SCIP data structure */ 295 ) 296 { 297 SCIP_BANDITVTABLE* vtable; 298 299 SCIP_CALL( SCIPincludeBanditvtable(scip, &vtable, BANDIT_NAME, 300 SCIPbanditFreeExp3IX, SCIPbanditSelectExp3IX, SCIPbanditUpdateExp3IX, SCIPbanditResetExp3IX) ); 301 assert(vtable != NULL); 302 303 return SCIP_OKAY; 304 } 305