1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 2 /* */ 3 /* This file is part of the program and library */ 4 /* SCIP --- Solving Constraint Integer Programs */ 5 /* */ 6 /* Copyright (c) 2002-2023 Zuse Institute Berlin (ZIB) */ 7 /* */ 8 /* Licensed under the Apache License, Version 2.0 (the "License"); */ 9 /* you may not use this file except in compliance with the License. */ 10 /* You may obtain a copy of the License at */ 11 /* */ 12 /* http://www.apache.org/licenses/LICENSE-2.0 */ 13 /* */ 14 /* Unless required by applicable law or agreed to in writing, software */ 15 /* distributed under the License is distributed on an "AS IS" BASIS, */ 16 /* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ 17 /* See the License for the specific language governing permissions and */ 18 /* limitations under the License. */ 19 /* */ 20 /* You should have received a copy of the Apache-2.0 license */ 21 /* along with SCIP; see the file LICENSE. If not visit scipopt.org. */ 22 /* */ 23 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 24 25 /**@file bandit.c 26 * @ingroup OTHER_CFILES 27 * @brief internal API of bandit algorithms and bandit virtual function tables 28 * @author Gregor Hendel 29 */ 30 31 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/ 32 33 #include <assert.h> 34 #include <string.h> 35 #include "scip/bandit.h" 36 #include "scip/pub_bandit.h" 37 #include "scip/struct_bandit.h" 38 #include "scip/struct_set.h" 39 #include "scip/set.h" 40 41 /** creates and resets bandit algorithm */ 42 SCIP_RETCODE SCIPbanditCreate( 43 SCIP_BANDIT** bandit, /**< pointer to bandit algorithm data structure */ 44 SCIP_BANDITVTABLE* banditvtable, /**< virtual table for this bandit algorithm */ 45 BMS_BLKMEM* blkmem, /**< block memory for parameter settings */ 46 BMS_BUFMEM* bufmem, /**< buffer memory */ 47 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */ 48 int nactions, /**< the positive number of actions for this bandit */ 49 unsigned int initseed, /**< initial seed for random number generation */ 50 SCIP_BANDITDATA* banditdata /**< algorithm specific bandit data */ 51 ) 52 { 53 SCIP_BANDIT* banditptr; 54 assert(bandit != NULL); 55 assert(banditvtable != NULL); 56 57 /* the number of actions must be positive */ 58 if( nactions <= 0 ) 59 { 60 SCIPerrorMessage("Cannot create bandit selector with %d <= 0 actions\n", nactions); 61 62 return SCIP_INVALIDDATA; 63 } 64 65 SCIP_ALLOC( BMSallocBlockMemory(blkmem, bandit) ); 66 assert(*bandit != NULL); 67 banditptr = *bandit; 68 banditptr->vtable = banditvtable; 69 banditptr->data = banditdata; 70 banditptr->nactions = nactions; 71 72 SCIP_CALL( SCIPrandomCreate(&banditptr->rng, blkmem, initseed) ); 73 74 SCIP_CALL( SCIPbanditReset(bufmem, banditptr, priorities, initseed) ); 75 76 return SCIP_OKAY; 77 } 78 79 /** calls destructor and frees memory of bandit algorithm */ 80 SCIP_RETCODE SCIPbanditFree( 81 BMS_BLKMEM* blkmem, /**< block memory */ 82 SCIP_BANDIT** bandit /**< pointer to bandit algorithm data structure */ 83 ) 84 { 85 SCIP_BANDIT* banditptr; 86 SCIP_BANDITVTABLE* vtable; 87 assert(bandit != NULL); 88 assert(*bandit != NULL); 89 90 banditptr = *bandit; 91 vtable = banditptr->vtable; 92 assert(vtable != NULL); 93 94 /* call bandit specific data destructor */ 95 if( vtable->banditfree != NULL ) 96 { 97 SCIP_CALL( vtable->banditfree(blkmem, banditptr) ); 98 } 99 100 /* free random number generator */ 101 SCIPrandomFree(&banditptr->rng, blkmem); 102 103 BMSfreeBlockMemory(blkmem, bandit); 104 105 return SCIP_OKAY; 106 } 107 108 /** reset the bandit algorithm */ 109 SCIP_RETCODE SCIPbanditReset( 110 BMS_BUFMEM* bufmem, /**< buffer memory */ 111 SCIP_BANDIT* bandit, /**< pointer to bandit algorithm data structure */ 112 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */ 113 unsigned int seed /**< initial seed for random number generation */ 114 ) 115 { 116 SCIP_BANDITVTABLE* vtable; 117 118 assert(bandit != NULL); 119 assert(bufmem != NULL); 120 121 vtable = bandit->vtable; 122 assert(vtable != NULL); 123 assert(vtable->banditreset != NULL); 124 125 /* test if the priorities are nonnegative */ 126 if( priorities != NULL ) 127 { 128 int i; 129 130 assert(SCIPbanditGetNActions(bandit) > 0); 131 132 for( i = 0; i < SCIPbanditGetNActions(bandit); ++i ) 133 { 134 if( priorities[i] < 0 ) 135 { 136 SCIPerrorMessage("Negative priority for action %d\n", i); 137 138 return SCIP_INVALIDDATA; 139 } 140 } 141 } 142 143 /* reset the random seed of the bandit algorithm */ 144 SCIPrandomSetSeed(bandit->rng, seed); 145 146 /* call the reset callback of the bandit algorithm */ 147 SCIP_CALL( vtable->banditreset(bufmem, bandit, priorities) ); 148 149 return SCIP_OKAY; 150 } 151 152 /** select the next action */ 153 SCIP_RETCODE SCIPbanditSelect( 154 SCIP_BANDIT* bandit, /**< bandit algorithm data structure */ 155 int* action /**< pointer to store the selected action */ 156 ) 157 { 158 assert(bandit != NULL); 159 assert(action != NULL); 160 161 *action = -1; 162 163 assert(bandit->vtable->banditselect != NULL); 164 165 SCIP_CALL( bandit->vtable->banditselect(bandit, action) ); 166 167 assert(*action >= 0); 168 assert(*action < SCIPbanditGetNActions(bandit)); 169 170 return SCIP_OKAY; 171 } 172 173 /** update the score of the selected action */ 174 SCIP_RETCODE SCIPbanditUpdate( 175 SCIP_BANDIT* bandit, /**< bandit algorithm data structure */ 176 int action, /**< index of action for which the score should be updated */ 177 SCIP_Real score /**< observed gain of the i'th action */ 178 ) 179 { 180 assert(bandit != NULL); 181 assert(0 <= action && action < SCIPbanditGetNActions(bandit)); 182 assert(bandit->vtable->banditupdate != NULL); 183 184 SCIP_CALL( bandit->vtable->banditupdate(bandit, action, score) ); 185 186 return SCIP_OKAY; 187 } 188 189 /** get data of this bandit algorithm */ 190 SCIP_BANDITDATA* SCIPbanditGetData( 191 SCIP_BANDIT* bandit /**< bandit algorithm data structure */ 192 ) 193 { 194 assert(bandit != NULL); 195 196 return bandit->data; 197 } 198 199 /** set the data of this bandit algorithm */ 200 void SCIPbanditSetData( 201 SCIP_BANDIT* bandit, /**< bandit algorithm data structure */ 202 SCIP_BANDITDATA* banditdata /**< bandit algorihm specific data, or NULL */ 203 ) 204 { 205 assert(bandit != NULL); 206 207 bandit->data = banditdata; 208 } 209 210 /** internal method to create a bandit VTable */ 211 static 212 SCIP_RETCODE doBanditvtableCreate( 213 SCIP_BANDITVTABLE** banditvtable, /**< pointer to virtual table for bandit algorithm */ 214 const char* name, /**< a name for the algorithm represented by this vtable */ 215 SCIP_DECL_BANDITFREE ((*banditfree)), /**< callback to free bandit specific data structures */ 216 SCIP_DECL_BANDITSELECT((*banditselect)), /**< selection callback for bandit selector */ 217 SCIP_DECL_BANDITUPDATE((*banditupdate)), /**< update callback for bandit algorithms */ 218 SCIP_DECL_BANDITRESET ((*banditreset)) /**< update callback for bandit algorithms */ 219 ) 220 { 221 SCIP_BANDITVTABLE* banditvtableptr; 222 223 assert(banditvtable != NULL); 224 assert(name != NULL); 225 assert(banditfree != NULL); 226 assert(banditselect != NULL); 227 assert(banditupdate != NULL); 228 assert(banditreset != NULL); 229 230 /* allocate memory for this virtual function table */ 231 SCIP_ALLOC( BMSallocMemory(banditvtable) ); 232 BMSclearMemory(*banditvtable); 233 234 SCIP_ALLOC( BMSduplicateMemoryArray(&(*banditvtable)->name, name, strlen(name)+1) ); 235 banditvtableptr = *banditvtable; 236 banditvtableptr->banditfree = banditfree; 237 banditvtableptr->banditselect = banditselect; 238 banditvtableptr->banditupdate = banditupdate; 239 banditvtableptr->banditreset = banditreset; 240 241 return SCIP_OKAY; 242 } 243 244 /** create a bandit VTable for bandit algorithm callback functions */ 245 SCIP_RETCODE SCIPbanditvtableCreate( 246 SCIP_BANDITVTABLE** banditvtable, /**< pointer to virtual table for bandit algorithm */ 247 const char* name, /**< a name for the algorithm represented by this vtable */ 248 SCIP_DECL_BANDITFREE ((*banditfree)), /**< callback to free bandit specific data structures */ 249 SCIP_DECL_BANDITSELECT((*banditselect)), /**< selection callback for bandit selector */ 250 SCIP_DECL_BANDITUPDATE((*banditupdate)), /**< update callback for bandit algorithms */ 251 SCIP_DECL_BANDITRESET ((*banditreset)) /**< update callback for bandit algorithms */ 252 ) 253 { 254 assert(banditvtable != NULL); 255 assert(name != NULL); 256 assert(banditfree != NULL); 257 assert(banditselect != NULL); 258 assert(banditupdate != NULL); 259 assert(banditreset != NULL); 260 261 SCIP_CALL_FINALLY( doBanditvtableCreate(banditvtable, name, banditfree, banditselect, banditupdate, banditreset), 262 SCIPbanditvtableFree(banditvtable) ); 263 264 return SCIP_OKAY; 265 } 266 267 268 /** free a bandit virtual table for bandit algorithm callback functions */ 269 void SCIPbanditvtableFree( 270 SCIP_BANDITVTABLE** banditvtable /**< pointer to virtual table for bandit algorithm */ 271 ) 272 { 273 assert(banditvtable != NULL); 274 if( *banditvtable == NULL ) 275 return; 276 277 BMSfreeMemoryArrayNull(&(*banditvtable)->name); 278 BMSfreeMemory(banditvtable); 279 } 280 281 /** return the name of this bandit virtual function table */ 282 const char* SCIPbanditvtableGetName( 283 SCIP_BANDITVTABLE* banditvtable /**< virtual table for bandit algorithm */ 284 ) 285 { 286 assert(banditvtable != NULL); 287 288 return banditvtable->name; 289 } 290 291 292 /** return the random number generator of a bandit algorithm */ 293 SCIP_RANDNUMGEN* SCIPbanditGetRandnumgen( 294 SCIP_BANDIT* bandit /**< bandit algorithm data structure */ 295 ) 296 { 297 assert(bandit != NULL); 298 299 return bandit->rng; 300 } 301 302 /** return number of actions of this bandit algorithm */ 303 int SCIPbanditGetNActions( 304 SCIP_BANDIT* bandit /**< bandit algorithm data structure */ 305 ) 306 { 307 assert(bandit != NULL); 308 309 return bandit->nactions; 310 } 311