1    	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2    	/*                                                                           */
3    	/*                  This file is part of the program and library             */
4    	/*         SCIP --- Solving Constraint Integer Programs                      */
5    	/*                                                                           */
6    	/*  Copyright 2002-2022 Zuse Institute Berlin                                */
7    	/*                                                                           */
8    	/*  Licensed under the Apache License, Version 2.0 (the "License");          */
9    	/*  you may not use this file except in compliance with the License.         */
10   	/*  You may obtain a copy of the License at                                  */
11   	/*                                                                           */
12   	/*      http://www.apache.org/licenses/LICENSE-2.0                           */
13   	/*                                                                           */
14   	/*  Unless required by applicable law or agreed to in writing, software      */
15   	/*  distributed under the License is distributed on an "AS IS" BASIS,        */
16   	/*  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17   	/*  See the License for the specific language governing permissions and      */
18   	/*  limitations under the License.                                           */
19   	/*                                                                           */
20   	/*  You should have received a copy of the Apache-2.0 license                */
21   	/*  along with SCIP; see the file LICENSE. If not visit scipopt.org.         */
22   	/*                                                                           */
23   	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24   	
25   	/**@file   bandit_exp3ix.c
26   	 * @ingroup OTHER_CFILES
27   	 * @brief  methods for Exp.3-IX bandit selection
28   	 * @author Antonia Chmiela
29   	 */
30   	
31   	/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32   	
33   	#include "scip/bandit.h"
34   	#include "scip/bandit_exp3ix.h"
35   	#include "scip/pub_bandit.h"
36   	#include "scip/pub_message.h"
37   	#include "scip/pub_misc.h"
38   	#include "scip/scip_bandit.h"
39   	#include "scip/scip_mem.h"
40   	#include "scip/scip_randnumgen.h"
41   	
/** name under which the Exp.3-IX virtual function table is included and looked up */
#define BANDIT_NAME "exp3ix"
43   	
44   	/*
45   	 * Data structures
46   	 */
47   	
/** implementation specific data of Exp.3-IX bandit algorithm
 *
 *  The reset callback sets every weight to 1.0, the weight sum to the number of actions, and the iteration
 *  counter to 1; the update callback rescales the weight of the played arm, adjusts the sum accordingly, and
 *  increments the iteration counter.
 */
struct SCIP_BanditData
{
   SCIP_Real*            weights;            /**< exponential weight for each arm; array with one entry per action */
   SCIP_Real             weightsum;          /**< the sum of all weights, maintained incrementally by the update callback */
   int                   iter;               /**< current iteration counter (starting at 1) to compute parameters gamma_t
                                              *   and eta_t */
};
55   	
56   	/*
57   	 * Local methods
58   	 */
59   	
60   	/*
61   	 * Callback methods of bandit algorithm
62   	 */
63   	
64   	/** callback to free bandit specific data structures */
65   	SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3IX)
66   	{  /*lint --e{715}*/
67   	   SCIP_BANDITDATA* banditdata;
68   	   int nactions;
69   	   assert(bandit != NULL);
70   	
71   	   banditdata = SCIPbanditGetData(bandit);
72   	   assert(banditdata != NULL);
73   	   nactions = SCIPbanditGetNActions(bandit);
74   	
75   	   BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
76   	
77   	   BMSfreeBlockMemory(blkmem, &banditdata);
78   	
79   	   SCIPbanditSetData(bandit, NULL);
80   	
81   	   return SCIP_OKAY;
82   	}
83   	
84   	/** selection callback for bandit selector */
85   	SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3IX)
86   	{  /*lint --e{715}*/
87   	   SCIP_BANDITDATA* banditdata;
88   	   SCIP_RANDNUMGEN* rng;
89   	   SCIP_Real* weights;
90   	   SCIP_Real weightsum;
91   	   int i;
92   	   int nactions;
93   	   SCIP_Real psum;
94   	   SCIP_Real randnr;
95   	
96   	   assert(bandit != NULL);
97   	   assert(selection != NULL);
98   	
99   	   banditdata = SCIPbanditGetData(bandit);
100  	   assert(banditdata != NULL);
101  	   rng = SCIPbanditGetRandnumgen(bandit);
102  	   assert(rng != NULL);
103  	   nactions = SCIPbanditGetNActions(bandit);
104  	
105  	   /* initialize some local variables to speed up probability computations */
106  	   weightsum = banditdata->weightsum;
107  	   weights = banditdata->weights;
108  	
109  	   /* draw a random number between 0 and 1 */
110  	   randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
111  	
112  	   /* loop over probability distribution until rand is reached
113  	    * the loop terminates without looking at the last action,
114  	    * which is then selected automatically if the target probability
115  	    * is not reached earlier
116  	    */
117  	   psum = 0.0;
118  	   for( i = 0; i < nactions - 1; ++i )
119  	   {
120  	      SCIP_Real prob;
121  	
122  	      /* compute the probability for arm i */
123  	      prob = weights[i] / weightsum;
124  	      psum += prob;
125  	
126  	      /* break and select element if target probability is reached */
127  	      if( randnr <= psum )
128  	         break;
129  	   }
130  	
131  	   /* select element i, which is the last action in case that the break statement hasn't been reached */
132  	   *selection = i;
133  	
134  	   return SCIP_OKAY;
135  	}
136  	
137  	/** compute gamma_t */
138  	static
139  	SCIP_Real SCIPcomputeGamma(
140  	   int                   nactions,           /**< the positive number of actions for this bandit algorithm */
141  	   int                   t                   /**< current iteration */
142  	   )
143  	{
144  	   return SQRT(log((SCIP_Real)nactions) / (4.0 * (SCIP_Real)t * (SCIP_Real)nactions) );
145  	}
146  	
147  	/** update callback for bandit algorithm */
148  	SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3IX)
149  	{  /*lint --e{715}*/
150  	   SCIP_BANDITDATA* banditdata;
151  	   SCIP_Real etaparam;
152  	   SCIP_Real lossestim;
153  	   SCIP_Real prob;
154  	   SCIP_Real weightsum;
155  	   SCIP_Real newweightsum;
156  	   SCIP_Real* weights;
157  	   SCIP_Real gammaparam;
158  	   int nactions;
159  	
160  	   assert(bandit != NULL);
161  	
162  	   banditdata = SCIPbanditGetData(bandit);
163  	   assert(banditdata != NULL);
164  	   nactions = SCIPbanditGetNActions(bandit);
165  	
166  	   assert(selection >= 0);
167  	   assert(selection < nactions);
168  	
169  	   weights = banditdata->weights;
170  	   weightsum = banditdata->weightsum;
171  	   newweightsum = weightsum;
172  	   gammaparam = SCIPcomputeGamma(nactions, banditdata->iter);
173  	   etaparam = 2.0 * gammaparam;
174  	
175  	   /* probability of selection */
176  	   prob = weights[selection] / weightsum;
177  	
178  	   /* estimated loss */
179  	   lossestim = (1.0 - score) / (prob + gammaparam);
180  	   assert(lossestim >= 0);
181  	
182  	   /* update the observation for the current arm */
183  	   newweightsum -= weights[selection];
184  	   weights[selection] *= exp(-etaparam * lossestim);
185  	   newweightsum += weights[selection];
186  	
187  	   banditdata->weightsum = newweightsum;
188  	
189  	   /* increase iteration counter */
190  	   banditdata->iter += 1;
191  	
192  	   return SCIP_OKAY;
193  	}
194  	
195  	/** reset callback for bandit algorithm */
196  	SCIP_DECL_BANDITRESET(SCIPbanditResetExp3IX)
197  	{  /*lint --e{715}*/
198  	   SCIP_BANDITDATA* banditdata;
199  	   SCIP_Real* weights;
200  	   int nactions;
201  	   int i;
202  	
203  	   assert(bandit != NULL);
204  	
205  	   banditdata = SCIPbanditGetData(bandit);
206  	   assert(banditdata != NULL);
207  	   nactions = SCIPbanditGetNActions(bandit);
208  	   weights = banditdata->weights;
209  	
210  	   assert(nactions > 0);
211  	
212  	   /* initialize all weights with 1.0 */
213  	   for( i = 0; i < nactions; ++i )
214  	      weights[i] = 1.0;
215  	
216  	   banditdata->weightsum = (SCIP_Real)nactions;
217  	
218  	   /* set iteration counter to 1 */
219  	   banditdata->iter = 1;
220  	
221  	   return SCIP_OKAY;
222  	}
223  	
224  	
225  	/*
226  	 * bandit algorithm specific interface methods
227  	 */
228  	
229  	/** direct bandit creation method for the core where no SCIP pointer is available */
230  	SCIP_RETCODE SCIPbanditCreateExp3IX(
231  	   BMS_BLKMEM*           blkmem,             /**< block memory data structure */
232  	   BMS_BUFMEM*           bufmem,             /**< buffer memory */
233  	   SCIP_BANDITVTABLE*    vtable,             /**< virtual function table for callback functions of Exp.3-IX */
234  	   SCIP_BANDIT**         exp3ix,             /**< pointer to store bandit algorithm */
235  	   SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
236  	   int                   nactions,           /**< the positive number of actions for this bandit algorithm */
237  	   unsigned int          initseed            /**< initial random seed */
238  	   )
239  	{
240  	   SCIP_BANDITDATA* banditdata;
241  	
242  	   SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
243  	   assert(banditdata != NULL);
244  	
245  	   banditdata->iter = 1;
246  	
247  	   SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
248  	
249  	   SCIP_CALL( SCIPbanditCreate(exp3ix, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
250  	
251  	   return SCIP_OKAY;
252  	}
253  	
254  	/** creates and resets an Exp.3-IX bandit algorithm using \p scip pointer */
255  	SCIP_RETCODE SCIPcreateBanditExp3IX(
256  	   SCIP*                 scip,               /**< SCIP data structure */
257  	   SCIP_BANDIT**         exp3ix,             /**< pointer to store bandit algorithm */
258  	   SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
259  	   int                   nactions,           /**< the positive number of actions for this bandit algorithm */
260  	   unsigned int          initseed            /**< initial seed for random number generation */
261  	   )
262  	{
263  	   SCIP_BANDITVTABLE* vtable;
264  	
265  	   vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
266  	   if( vtable == NULL )
267  	   {
268  	      SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
269  	      return SCIP_INVALIDDATA;
270  	   }
271  	
272  	   SCIP_CALL( SCIPbanditCreateExp3IX(SCIPblkmem(scip), SCIPbuffer(scip), vtable, exp3ix,
273  	         priorities, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
274  	
275  	   return SCIP_OKAY;
276  	}
277  	
278  	/** returns probability to play an action */
279  	SCIP_Real SCIPgetProbabilityExp3IX(
280  	   SCIP_BANDIT*          exp3ix,             /**< bandit algorithm */
281  	   int                   action              /**< index of the requested action */
282  	   )
283  	{
284  	   SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3ix);
285  	
286  	   assert(banditdata->weightsum > 0.0);
287  	   assert(SCIPbanditGetNActions(exp3ix) > 0);
288  	
289  	   return banditdata->weights[action] / banditdata->weightsum;
290  	}
291  	
292  	/** include virtual function table for Exp.3-IX bandit algorithms */
293  	SCIP_RETCODE SCIPincludeBanditvtableExp3IX(
294  	   SCIP*                 scip                /**< SCIP data structure */
295  	   )
296  	{
297  	   SCIP_BANDITVTABLE* vtable;
298  	
299  	   SCIP_CALL( SCIPincludeBanditvtable(scip, &vtable, BANDIT_NAME,
300  	         SCIPbanditFreeExp3IX, SCIPbanditSelectExp3IX, SCIPbanditUpdateExp3IX, SCIPbanditResetExp3IX) );
301  	   assert(vtable != NULL);
302  	
303  	   return SCIP_OKAY;
304  	}
305