1    	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2    	/*                                                                           */
3    	/*                  This file is part of the program and library             */
4    	/*         SCIP --- Solving Constraint Integer Programs                      */
5    	/*                                                                           */
6    	/*  Copyright (c) 2002-2023 Zuse Institute Berlin (ZIB)                      */
7    	/*                                                                           */
8    	/*  Licensed under the Apache License, Version 2.0 (the "License");          */
9    	/*  you may not use this file except in compliance with the License.         */
10   	/*  You may obtain a copy of the License at                                  */
11   	/*                                                                           */
12   	/*      http://www.apache.org/licenses/LICENSE-2.0                           */
13   	/*                                                                           */
14   	/*  Unless required by applicable law or agreed to in writing, software      */
15   	/*  distributed under the License is distributed on an "AS IS" BASIS,        */
16   	/*  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17   	/*  See the License for the specific language governing permissions and      */
18   	/*  limitations under the License.                                           */
19   	/*                                                                           */
20   	/*  You should have received a copy of the Apache-2.0 license                */
21   	/*  along with SCIP; see the file LICENSE. If not visit scipopt.org.         */
22   	/*                                                                           */
23   	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24   	
25   	/**@file   bandit_epsgreedy.c
26   	 * @ingroup OTHER_CFILES
27   	 * @brief  implementation of (a variant of) epsilon greedy bandit algorithm
28   	 * @author Gregor Hendel
29   	 */
30   	
31   	/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32   	
33   	#include "scip/bandit.h"
34   	#include "scip/bandit_epsgreedy.h"
35   	#include "scip/pub_bandit.h"
36   	#include "scip/pub_message.h"
37   	#include "scip/pub_misc.h"
38   	#include "scip/scip_bandit.h"
39   	#include "scip/scip_mem.h"
40   	#include "scip/scip_randnumgen.h"
41   	
/** name under which the epsilon greedy virtual function table is included in SCIP and looked up again */
#define BANDIT_NAME           "eps-greedy"
/** tolerance below the best weight within which priorities decide the selection; also the bound of the random
 *  perturbation that is added to user priorities during a reset
 */
#define EPSGREEDY_SMALL       1e-6
44   	
45   	/*
46   	 * Data structures
47   	 */
48   	
49   	/** private data structure of epsilon greedy bandit algorithm */
50   	struct SCIP_BanditData
51   	{
52   	   SCIP_Real*            weights;            /**< weights for every action */
53   	   SCIP_Real*            priorities;         /**< saved priorities for tie breaking */
54   	   int*                  sels;               /**< individual number of selections per action */
55   	   SCIP_Real             eps;                /**< epsilon parameter (between 0 and 1) to control epsilon greedy */
56   	   SCIP_Bool             usemodification;    /**< TRUE if modified eps greedy should be used */
57   	   SCIP_Real             decayfactor;        /**< the factor to reduce the weight of older observations if exponential decay is enabled */
58   	   int                   avglim;             /**< nonnegative limit on observation number before the exponential decay starts,
59   	                                               *  only relevant if exponential decay is enabled
60   	                                               */
61   	   int                   nselections;        /**< counter for the number of selection calls */
62   	   SCIP_Bool             preferrecent;       /**< should the weights be updated in an exponentially decaying way? */
63   	};
64   	
65   	/*
66   	 * Callback methods of bandit algorithm virtual function table
67   	 */
68   	
69   	/** callback to free bandit specific data structures */
70   	SCIP_DECL_BANDITFREE(SCIPbanditFreeEpsgreedy)
71   	{  /*lint --e{715}*/
72   	   SCIP_BANDITDATA* banditdata;
73   	   int nactions;
74   	
75   	   assert(bandit != NULL);
76   	
77   	   banditdata = SCIPbanditGetData(bandit);
78   	   assert(banditdata != NULL);
79   	   assert(banditdata->weights != NULL);
80   	   nactions = SCIPbanditGetNActions(bandit);
81   	
82   	   BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
83   	   BMSfreeBlockMemoryArray(blkmem, &banditdata->priorities, nactions);
84   	   BMSfreeBlockMemoryArray(blkmem, &banditdata->sels, nactions);
85   	   BMSfreeBlockMemory(blkmem, &banditdata);
86   	
87   	   SCIPbanditSetData(bandit, NULL);
88   	
89   	   return SCIP_OKAY;
90   	}
91   	
92   	/** selection callback for bandit algorithm */
93   	SCIP_DECL_BANDITSELECT(SCIPbanditSelectEpsgreedy)
94   	{  /*lint --e{715}*/
95   	   SCIP_BANDITDATA* banditdata;
96   	   SCIP_Real randnr;
97   	   SCIP_Real curreps;
98   	   SCIP_RANDNUMGEN* rng;
99   	   int nactions;
100  	   assert(bandit != NULL);
101  	   assert(selection != NULL);
102  	
103  	   banditdata = SCIPbanditGetData(bandit);
104  	   assert(banditdata != NULL);
105  	   rng = SCIPbanditGetRandnumgen(bandit);
106  	   assert(rng != NULL);
107  	
108  	   nactions = SCIPbanditGetNActions(bandit);
109  	
110  	   /* roll the dice to check if the best element should be picked, or an element at random */
111  	   randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
112  	
113  	   /* make epsilon decrease with an increasing number of selections */
114  	   banditdata->nselections++;
115  	   curreps = banditdata->eps * sqrt((SCIP_Real)nactions/(SCIP_Real)banditdata->nselections);
116  	
117  	   /* select the best action seen so far */
118  	   if( randnr >= curreps )
119  	   {
120  	      SCIP_Real* weights = banditdata->weights;
121  	      SCIP_Real* priorities = banditdata->priorities;
122  	      int j;
123  	      SCIP_Real maxweight;
124  	
125  	      assert(weights != NULL);
126  	      assert(priorities != NULL);
127  	
128  	      /* pick the element with the largest reward */
129  	      maxweight = weights[0];
130  	      *selection = 0;
131  	
132  	      /* determine reward for every element */
133  	      for( j = 1; j < nactions; ++j )
134  	      {
135  	         SCIP_Real weight = weights[j];
136  	
137  	         /* select the action that maximizes the reward, breaking ties by action priorities */
138  	         if( maxweight < weight
139  	               || (weight >= maxweight - EPSGREEDY_SMALL && priorities[j] > priorities[*selection] ) )
140  	         {
141  	            *selection = j;
142  	            maxweight = weight;
143  	         }
144  	      }
145  	   }
146  	   else if( ! banditdata->usemodification ) /* use normal eps greedy */
147  	   {
148  	      /* play one of the actions at random */
149  	      *selection = SCIPrandomGetInt(rng, 0, nactions - 1);
150  	   }
151  	   else /* pick an action w.r.t. the distributions defined by its weights */
152  	   {
153  	      int j;
154  	      SCIP_Real sum;
155  	      SCIP_Real weightsum;
156  	      SCIP_Real* weights = banditdata->weights;
157  	
158  	      weightsum = 0.0;
159  	      for( j = 0; j < nactions; ++j )
160  	         weightsum += banditdata->weights[j];
161  	
162  	      /* pick a random number between 0.0 and sum of weights */
163  	      randnr = SCIPrandomGetReal(rng, 0.0, weightsum);
164  	
165  	      /* pick action w.r.t. the weights distribution */
166  	      sum = 0.0;
167  	      *selection = -1;
168  	      for( j = 0; j < nactions - 1; ++j )
169  	      {
170  	         sum += weights[j];
171  	
172  	         if( sum >= randnr )
173  	         {
174  	            *selection = j;
175  	            break;
176  	         }
177  	      }
178  	
179  	      if( *selection < 0 )
180  	         *selection = nactions - 1;
181  	   }
182  	
183  	   return SCIP_OKAY;
184  	}
185  	
186  	/** update callback for bandit algorithm */
187  	SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateEpsgreedy)
188  	{  /*lint --e{715}*/
189  	   SCIP_BANDITDATA* banditdata;
190  	
191  	   assert(bandit != NULL);
192  	
193  	   banditdata = SCIPbanditGetData(bandit);
194  	   assert(banditdata != NULL);
195  	
196  	   /* increase the selection count */
197  	   ++banditdata->sels[selection];
198  	
199  	   /* the very first observation is directly stored as weight for both average or exponential decay */
200  	   if( banditdata->sels[selection] == 1 )
201  	      banditdata->weights[selection] = score;
202  	   else
203  	   {
204  	      /* use exponentially decreasing weights for older observations */
205  	      if( banditdata->preferrecent && banditdata->sels[selection] > banditdata->avglim )
206  	      {
207  	         /* decrease old weights by decay factor */
208  	         banditdata->weights[selection] *= banditdata->decayfactor;
209  	         banditdata->weights[selection] += (1.0 - banditdata->decayfactor) * score;
210  	      }
211  	      else
212  	      {
213  	         /* update average score */
214  	         SCIP_Real diff = score - banditdata->weights[selection];
215  	         banditdata->weights[selection] += diff / (SCIP_Real)(banditdata->sels[selection]);
216  	      }
217  	   }
218  	
219  	   return SCIP_OKAY;
220  	}
221  	
222  	/** reset callback for bandit algorithm */
223  	SCIP_DECL_BANDITRESET(SCIPbanditResetEpsgreedy)
224  	{  /*lint --e{715}*/
225  	   SCIP_BANDITDATA* banditdata;
226  	   SCIP_Real* weights;
227  	   int w;
228  	   int nactions;
229  	   SCIP_RANDNUMGEN* rng;
230  	
231  	   assert(bandit != NULL);
232  	
233  	   banditdata = SCIPbanditGetData(bandit);
234  	   assert(banditdata != NULL);
235  	
236  	   weights = banditdata->weights;
237  	   nactions = SCIPbanditGetNActions(bandit);
238  	   assert(weights != NULL);
239  	   assert(banditdata->priorities != NULL);
240  	   assert(nactions > 0);
241  	
242  	   rng = SCIPbanditGetRandnumgen(bandit);
243  	   assert(rng != NULL);
244  	
245  	   /* alter priorities slightly to make them unique */
246  	   if( priorities != NULL )
247  	   {
248  	      for( w = 1; w < nactions; ++w )
249  	      {
250  	         assert(priorities[w] >= 0);
251  	         banditdata->priorities[w] = priorities[w] + SCIPrandomGetReal(rng, -EPSGREEDY_SMALL, EPSGREEDY_SMALL);
252  	      }
253  	   }
254  	   else
255  	   {
256  	      /* use random priorities */
257  	      for( w = 0; w < nactions; ++w )
258  	         banditdata->priorities[w] = SCIPrandomGetReal(rng, 0.0, 1.0);
259  	   }
260  	
261  	   /* reset weights and selection counters to 0 */
262  	   BMSclearMemoryArray(weights, nactions);
263  	   BMSclearMemoryArray(banditdata->sels, nactions);
264  	
265  	   banditdata->nselections = 0;
266  	
267  	   return SCIP_OKAY;
268  	}
269  	
270  	/*
271  	 * interface methods of the Epsilon Greedy bandit algorithm
272  	 */
273  	
274  	/** internal method to create and reset epsilon greedy bandit algorithm */
275  	SCIP_RETCODE SCIPbanditCreateEpsgreedy(
276  	   BMS_BLKMEM*           blkmem,             /**< block memory */
277  	   BMS_BUFMEM*           bufmem,             /**< buffer memory */
278  	   SCIP_BANDITVTABLE*    vtable,             /**< virtual function table with epsilon greedy callbacks */
279  	   SCIP_BANDIT**         epsgreedy,          /**< pointer to store the epsilon greedy bandit algorithm */
280  	   SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
281  	   SCIP_Real             eps,                /**< parameter to increase probability for exploration between all actions */
282  	   SCIP_Bool             usemodification,    /**< TRUE if modified eps greedy should be used */
283  	   SCIP_Bool             preferrecent,       /**< should the weights be updated in an exponentially decaying way? */
284  	   SCIP_Real             decayfactor,        /**< the factor to reduce the weight of older observations if exponential decay is enabled */
285  	   int                   avglim,             /**< nonnegative limit on observation number before the exponential decay starts,
286  	                                              *   only relevant if exponential decay is enabled */
287  	   int                   nactions,           /**< the positive number of possible actions */
288  	   unsigned int          initseed            /**< initial random seed */
289  	   )
290  	{
291  	   SCIP_BANDITDATA* banditdata;
292  	
293  	   SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
294  	   assert(banditdata != NULL);
295  	   assert(eps >= 0.0);
296  	
297  	   SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
298  	   SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->priorities, nactions) );
299  	   SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->sels, nactions) );
300  	   banditdata->eps = eps;
301  	   banditdata->nselections = 0;
302  	   banditdata->usemodification = usemodification;
303  	   banditdata->preferrecent = preferrecent;
304  	   banditdata->decayfactor = decayfactor;
305  	   banditdata->avglim = avglim;
306  	
307  	   SCIP_CALL( SCIPbanditCreate(epsgreedy, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
308  	
309  	   return SCIP_OKAY;
310  	}
311  	
312  	/** create and resets an epsilon greedy bandit algorithm */
313  	SCIP_RETCODE SCIPcreateBanditEpsgreedy(
314  	   SCIP*                 scip,               /**< SCIP data structure */
315  	   SCIP_BANDIT**         epsgreedy,          /**< pointer to store the epsilon greedy bandit algorithm */
316  	   SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
317  	   SCIP_Real             eps,                /**< parameter to increase probability for exploration between all actions */
318  	   SCIP_Bool             usemodification,    /**< TRUE if modified eps greedy should be used */
319  	   SCIP_Bool             preferrecent,       /**< should the weights be updated in an exponentially decaying way? */
320  	   SCIP_Real             decayfactor,        /**< the factor to reduce the weight of older observations if exponential decay is enabled */
321  	   int                   avglim,             /**< nonnegative limit on observation number before the exponential decay starts,
322  	                                              *   only relevant if exponential decay is enabled */
323  	   int                   nactions,           /**< the positive number of possible actions */
324  	   unsigned int          initseed            /**< initial seed for random number generation */
325  	   )
326  	{
327  	   SCIP_BANDITVTABLE* vtable;
328  	   assert(scip != NULL);
329  	   assert(epsgreedy != NULL);
330  	
331  	   vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
332  	   if( vtable == NULL )
333  	   {
334  	      SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
335  	      return SCIP_INVALIDDATA;
336  	   }
337  	
338  	   SCIP_CALL( SCIPbanditCreateEpsgreedy(SCIPblkmem(scip), SCIPbuffer(scip), vtable, epsgreedy,
339  	         priorities, eps, usemodification, preferrecent, decayfactor, avglim, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
340  	
341  	   return SCIP_OKAY;
342  	}
343  	
344  	/** get weights array of epsilon greedy bandit algorithm */
345  	SCIP_Real* SCIPgetWeightsEpsgreedy(
346  	   SCIP_BANDIT*          epsgreedy           /**< epsilon greedy bandit algorithm */
347  	   )
348  	{
349  	   SCIP_BANDITDATA* banditdata;
350  	   assert(epsgreedy != NULL);
351  	   banditdata = SCIPbanditGetData(epsgreedy);
352  	   assert(banditdata != NULL);
353  	
354  	   return banditdata->weights;
355  	}
356  	
357  	/** set epsilon parameter of epsilon greedy bandit algorithm */
358  	void SCIPsetEpsilonEpsgreedy(
359  	   SCIP_BANDIT*          epsgreedy,          /**< epsilon greedy bandit algorithm */
360  	   SCIP_Real             eps                 /**< parameter to increase probability for exploration between all actions */
361  	   )
362  	{
363  	   SCIP_BANDITDATA* banditdata;
364  	   assert(epsgreedy != NULL);
365  	   assert(eps >= 0);
366  	
367  	   banditdata = SCIPbanditGetData(epsgreedy);
368  	
369  	   banditdata->eps = eps;
370  	}
371  	
372  	
373  	/** creates the epsilon greedy bandit algorithm includes it in SCIP */
374  	SCIP_RETCODE SCIPincludeBanditvtableEpsgreedy(
375  	   SCIP*                 scip                /**< SCIP data structure */
376  	   )
377  	{
378  	   SCIP_BANDITVTABLE* banditvtable;
379  	
380  	   SCIP_CALL( SCIPincludeBanditvtable(scip, &banditvtable, BANDIT_NAME,
381  	         SCIPbanditFreeEpsgreedy, SCIPbanditSelectEpsgreedy, SCIPbanditUpdateEpsgreedy, SCIPbanditResetEpsgreedy) );
382  	
383  	   return SCIP_OKAY;
384  	}
385