1    	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2    	/*                                                                           */
3    	/*                  This file is part of the program and library             */
4    	/*         SCIP --- Solving Constraint Integer Programs                      */
5    	/*                                                                           */
6    	/*  Copyright (c) 2002-2023 Zuse Institute Berlin (ZIB)                      */
7    	/*                                                                           */
8    	/*  Licensed under the Apache License, Version 2.0 (the "License");          */
9    	/*  you may not use this file except in compliance with the License.         */
10   	/*  You may obtain a copy of the License at                                  */
11   	/*                                                                           */
12   	/*      http://www.apache.org/licenses/LICENSE-2.0                           */
13   	/*                                                                           */
14   	/*  Unless required by applicable law or agreed to in writing, software      */
15   	/*  distributed under the License is distributed on an "AS IS" BASIS,        */
16   	/*  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17   	/*  See the License for the specific language governing permissions and      */
18   	/*  limitations under the License.                                           */
19   	/*                                                                           */
20   	/*  You should have received a copy of the Apache-2.0 license                */
21   	/*  along with SCIP; see the file LICENSE. If not visit scipopt.org.         */
22   	/*                                                                           */
23   	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24   	
25   	/**@file   bandit_exp3.c
26   	 * @ingroup OTHER_CFILES
27   	 * @brief  methods for Exp.3 bandit selection
28   	 * @author Gregor Hendel
29   	 */
30   	
31   	/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32   	
33   	#include "scip/bandit.h"
34   	#include "scip/bandit_exp3.h"
35   	#include "scip/pub_bandit.h"
36   	#include "scip/pub_message.h"
37   	#include "scip/pub_misc.h"
38   	#include "scip/scip_bandit.h"
39   	#include "scip/scip_mem.h"
40   	#include "scip/scip_randnumgen.h"
41   	
42   	#define BANDIT_NAME "exp3"
43   	#define NUMTOL 1e-6
44   	
45   	/*
46   	 * Data structures
47   	 */
48   	
/** implementation specific data of Exp.3 bandit algorithm
 *
 *  The callbacks in this file compute the probability of arm i from these fields as
 *  (1 - gamma) * weights[i] / weightsum + gamma / nactions.
 */
struct SCIP_BanditData
{
   SCIP_Real*            weights;            /**< exponential weight for each arm; allocated with one entry per action */
   SCIP_Real             weightsum;          /**< the sum of all weights; set by the reset callback and refreshed by every update */
   SCIP_Real             gamma;              /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
   SCIP_Real             beta;               /**< gain offset between 0 and 1 at every observation; if (numerically) zero, an update only
                                              *   touches the weight of the selected arm */
};
57   	
58   	/*
59   	 * Local methods
60   	 */
61   	
62   	/*
63   	 * Callback methods of bandit algorithm
64   	 */
65   	
66   	/** callback to free bandit specific data structures */
67   	SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
68   	{  /*lint --e{715}*/
69   	   SCIP_BANDITDATA* banditdata;
70   	   int nactions;
71   	   assert(bandit != NULL);
72   	
73   	   banditdata = SCIPbanditGetData(bandit);
74   	   assert(banditdata != NULL);
75   	   nactions = SCIPbanditGetNActions(bandit);
76   	
77   	   BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
78   	
79   	   BMSfreeBlockMemory(blkmem, &banditdata);
80   	
81   	   SCIPbanditSetData(bandit, NULL);
82   	
83   	   return SCIP_OKAY;
84   	}
85   	
86   	/** selection callback for bandit selector */
87   	SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
88   	{  /*lint --e{715}*/
89   	   SCIP_BANDITDATA* banditdata;
90   	   SCIP_RANDNUMGEN* rng;
91   	   SCIP_Real randnr;
92   	   SCIP_Real psum;
93   	   SCIP_Real gammaoverk;
94   	   SCIP_Real oneminusgamma;
95   	   SCIP_Real* weights;
96   	   SCIP_Real weightsum;
97   	   int i;
98   	   int nactions;
99   	
100  	   assert(bandit != NULL);
101  	   assert(selection != NULL);
102  	
103  	   banditdata = SCIPbanditGetData(bandit);
104  	   assert(banditdata != NULL);
105  	   rng = SCIPbanditGetRandnumgen(bandit);
106  	   assert(rng != NULL);
107  	   nactions = SCIPbanditGetNActions(bandit);
108  	
109  	   /* draw a random number between 0 and 1 */
110  	   randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
111  	
112  	   /* initialize some local variables to speed up probability computations */
113  	   oneminusgamma = 1 - banditdata->gamma;
114  	   gammaoverk = banditdata->gamma / (SCIP_Real)nactions;
115  	   weightsum = banditdata->weightsum;
116  	   weights = banditdata->weights;
117  	   psum = 0.0;
118  	
119  	   /* loop over probability distribution until rand is reached
120  	    * the loop terminates without looking at the last action,
121  	    * which is then selected automatically if the target probability
122  	    * is not reached earlier
123  	    */
124  	   for( i = 0; i < nactions - 1; ++i )
125  	   {
126  	      SCIP_Real prob;
127  	
128  	      /* compute the probability for arm i as convex kombination of a uniform distribution and a weighted distribution */
129  	      prob = oneminusgamma * weights[i] / weightsum + gammaoverk;
130  	      psum += prob;
131  	
132  	      /* break and select element if target probability is reached */
133  	      if( randnr <= psum )
134  	         break;
135  	   }
136  	
137  	   /* select element i, which is the last action in case that the break statement hasn't been reached */
138  	   *selection = i;
139  	
140  	   return SCIP_OKAY;
141  	}
142  	
143  	/** update callback for bandit algorithm */
144  	SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
145  	{  /*lint --e{715}*/
146  	   SCIP_BANDITDATA* banditdata;
147  	   SCIP_Real eta;
148  	   SCIP_Real gainestim;
149  	   SCIP_Real beta;
150  	   SCIP_Real weightsum;
151  	   SCIP_Real newweightsum;
152  	   SCIP_Real* weights;
153  	   SCIP_Real oneminusgamma;
154  	   SCIP_Real gammaoverk;
155  	   int nactions;
156  	
157  	   assert(bandit != NULL);
158  	
159  	   banditdata = SCIPbanditGetData(bandit);
160  	   assert(banditdata != NULL);
161  	   nactions = SCIPbanditGetNActions(bandit);
162  	
163  	   assert(selection >= 0);
164  	   assert(selection < nactions);
165  	
166  	   /* the learning rate eta */
167  	   eta = 1.0 / (SCIP_Real)nactions;
168  	
169  	   beta = banditdata->beta;
170  	   oneminusgamma = 1.0 - banditdata->gamma;
171  	   gammaoverk = banditdata->gamma * eta;
172  	   weights = banditdata->weights;
173  	   weightsum = banditdata->weightsum;
174  	   newweightsum = weightsum;
175  	
176  	   /* if beta is zero, only the observation for the current arm needs an update */
177  	   if( EPSZ(beta, NUMTOL) )
178  	   {
179  	      SCIP_Real probai;
180  	      probai = oneminusgamma * weights[selection] / weightsum + gammaoverk;
181  	
182  	      assert(probai > 0.0);
183  	
184  	      gainestim = score / probai;
185  	      newweightsum -= weights[selection];
186  	      weights[selection] *= exp(eta * gainestim);
187  	      newweightsum += weights[selection];
188  	   }
189  	   else
190  	   {
191  	      int j;
192  	      newweightsum = 0.0;
193  	
194  	      /* loop over all items and update their weights based on the influence of the beta parameter */
195  	      for( j = 0; j < nactions; ++j )
196  	      {
197  	         SCIP_Real probaj;
198  	         probaj = oneminusgamma * weights[j] / weightsum + gammaoverk;
199  	
200  	         assert(probaj > 0.0);
201  	
202  	         /* consider the score only for the chosen arm i, use constant beta offset otherwise */
203  	         if( j == selection )
204  	            gainestim = (score + beta) / probaj;
205  	         else
206  	            gainestim = beta / probaj;
207  	
208  	         weights[j] *= exp(eta * gainestim);
209  	         newweightsum += weights[j];
210  	      }
211  	   }
212  	
213  	   banditdata->weightsum = newweightsum;
214  	
215  	   return SCIP_OKAY;
216  	}
217  	
218  	/** reset callback for bandit algorithm */
219  	SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
220  	{  /*lint --e{715}*/
221  	   SCIP_BANDITDATA* banditdata;
222  	   SCIP_Real* weights;
223  	   int nactions;
224  	   int i;
225  	
226  	   assert(bandit != NULL);
227  	
228  	   banditdata = SCIPbanditGetData(bandit);
229  	   assert(banditdata != NULL);
230  	   nactions = SCIPbanditGetNActions(bandit);
231  	   weights = banditdata->weights;
232  	
233  	   assert(nactions > 0);
234  	
235  	   banditdata->weightsum = (1.0 + NUMTOL) * (SCIP_Real)nactions;
236  	
237  	   /* in case of priorities, weights are normalized to sum up to nactions */
238  	   if( priorities != NULL )
239  	   {
240  	      SCIP_Real normalization;
241  	      SCIP_Real priosum;
242  	      priosum = 0.0;
243  	
244  	      /* compute sum of priorities */
245  	      for( i = 0; i < nactions; ++i )
246  	      {
247  	         assert(priorities[i] >= 0);
248  	         priosum += priorities[i];
249  	      }
250  	
251  	      /* if there are positive priorities, normalize the weights */
252  	      if( priosum > 0.0 )
253  	      {
254  	         normalization = nactions / priosum;
255  	         for( i = 0; i < nactions; ++i )
256  	            weights[i] = (priorities[i] * normalization) + NUMTOL;
257  	      }
258  	      else
259  	      {
260  	         /* use uniform distribution in case of all priorities being 0.0 */
261  	         for( i = 0; i < nactions; ++i )
262  	            weights[i] = 1.0 + NUMTOL;
263  	      }
264  	   }
265  	   else
266  	   {
267  	      /* use uniform distribution in case of unspecified priorities */
268  	      for( i = 0; i < nactions; ++i )
269  	         weights[i] = 1.0 + NUMTOL;
270  	   }
271  	
272  	   return SCIP_OKAY;
273  	}
274  	
275  	
276  	/*
277  	 * bandit algorithm specific interface methods
278  	 */
279  	
280  	/** direct bandit creation method for the core where no SCIP pointer is available */
281  	SCIP_RETCODE SCIPbanditCreateExp3(
282  	   BMS_BLKMEM*           blkmem,             /**< block memory data structure */
283  	   BMS_BUFMEM*           bufmem,             /**< buffer memory */
284  	   SCIP_BANDITVTABLE*    vtable,             /**< virtual function table for callback functions of Exp.3 */
285  	   SCIP_BANDIT**         exp3,               /**< pointer to store bandit algorithm */
286  	   SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
287  	   SCIP_Real             gammaparam,         /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
288  	   SCIP_Real             beta,               /**< gain offset between 0 and 1 at every observation */
289  	   int                   nactions,           /**< the positive number of actions for this bandit algorithm */
290  	   unsigned int          initseed            /**< initial random seed */
291  	   )
292  	{
293  	   SCIP_BANDITDATA* banditdata;
294  	
295  	   SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
296  	   assert(banditdata != NULL);
297  	
298  	   banditdata->gamma = gammaparam;
299  	   banditdata->beta = beta;
300  	   assert(gammaparam >= 0 && gammaparam <= 1);
301  	   assert(beta >= 0 && beta <= 1);
302  	
303  	   SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
304  	
305  	   SCIP_CALL( SCIPbanditCreate(exp3, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
306  	
307  	   return SCIP_OKAY;
308  	}
309  	
310  	/** creates and resets an Exp.3 bandit algorithm using \p scip pointer */
311  	SCIP_RETCODE SCIPcreateBanditExp3(
312  	   SCIP*                 scip,               /**< SCIP data structure */
313  	   SCIP_BANDIT**         exp3,               /**< pointer to store bandit algorithm */
314  	   SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
315  	   SCIP_Real             gammaparam,         /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
316  	   SCIP_Real             beta,               /**< gain offset between 0 and 1 at every observation */
317  	   int                   nactions,           /**< the positive number of actions for this bandit algorithm */
318  	   unsigned int          initseed            /**< initial seed for random number generation */
319  	   )
320  	{
321  	   SCIP_BANDITVTABLE* vtable;
322  	
323  	   vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
324  	   if( vtable == NULL )
325  	   {
326  	      SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
327  	      return SCIP_INVALIDDATA;
328  	   }
329  	
330  	   SCIP_CALL( SCIPbanditCreateExp3(SCIPblkmem(scip), SCIPbuffer(scip), vtable, exp3,
331  	         priorities, gammaparam, beta, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
332  	
333  	   return SCIP_OKAY;
334  	}
335  	
336  	/** set gamma parameter of Exp.3 bandit algorithm to increase weight of uniform distribution */
337  	void SCIPsetGammaExp3(
338  	   SCIP_BANDIT*          exp3,               /**< bandit algorithm */
339  	   SCIP_Real             gammaparam          /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
340  	   )
341  	{
342  	   SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
343  	
344  	   assert(gammaparam >= 0 && gammaparam <= 1);
345  	
346  	   banditdata->gamma = gammaparam;
347  	}
348  	
349  	/** set beta parameter of Exp.3 bandit algorithm to increase gain offset for actions that were not played */
350  	void SCIPsetBetaExp3(
351  	   SCIP_BANDIT*          exp3,               /**< bandit algorithm */
352  	   SCIP_Real             beta                /**< gain offset between 0 and 1 at every observation */
353  	   )
354  	{
355  	   SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
356  	
357  	   assert(beta >= 0 && beta <= 1);
358  	
359  	   banditdata->beta = beta;
360  	}
361  	
362  	/** returns probability to play an action */
363  	SCIP_Real SCIPgetProbabilityExp3(
364  	   SCIP_BANDIT*          exp3,               /**< bandit algorithm */
365  	   int                   action              /**< index of the requested action */
366  	   )
367  	{
368  	   SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
369  	
370  	   assert(banditdata->weightsum > 0.0);
371  	   assert(SCIPbanditGetNActions(exp3) > 0);
372  	
373  	   return (1.0 - banditdata->gamma) * banditdata->weights[action] / banditdata->weightsum + banditdata->gamma / (SCIP_Real)SCIPbanditGetNActions(exp3);
374  	}
375  	
376  	/** include virtual function table for Exp.3 bandit algorithms */
377  	SCIP_RETCODE SCIPincludeBanditvtableExp3(
378  	   SCIP*                 scip                /**< SCIP data structure */
379  	   )
380  	{
381  	   SCIP_BANDITVTABLE* vtable;
382  	
383  	   SCIP_CALL( SCIPincludeBanditvtable(scip, &vtable, BANDIT_NAME,
384  	         SCIPbanditFreeExp3, SCIPbanditSelectExp3, SCIPbanditUpdateExp3, SCIPbanditResetExp3) );
385  	   assert(vtable != NULL);
386  	
387  	   return SCIP_OKAY;
388  	}
389