1    	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2    	/*                                                                           */
3    	/*                  This file is part of the program and library             */
4    	/*         SCIP --- Solving Constraint Integer Programs                      */
5    	/*                                                                           */
6    	/*  Copyright (c) 2002-2023 Zuse Institute Berlin (ZIB)                      */
7    	/*                                                                           */
8    	/*  Licensed under the Apache License, Version 2.0 (the "License");          */
9    	/*  you may not use this file except in compliance with the License.         */
10   	/*  You may obtain a copy of the License at                                  */
11   	/*                                                                           */
12   	/*      http://www.apache.org/licenses/LICENSE-2.0                           */
13   	/*                                                                           */
14   	/*  Unless required by applicable law or agreed to in writing, software      */
15   	/*  distributed under the License is distributed on an "AS IS" BASIS,        */
16   	/*  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17   	/*  See the License for the specific language governing permissions and      */
18   	/*  limitations under the License.                                           */
19   	/*                                                                           */
20   	/*  You should have received a copy of the Apache-2.0 license                */
21   	/*  along with SCIP; see the file LICENSE. If not visit scipopt.org.         */
22   	/*                                                                           */
23   	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24   	
25   	/**@file   bandit_ucb.c
26   	 * @ingroup OTHER_CFILES
27   	 * @brief  methods for UCB bandit selection
28   	 * @author Gregor Hendel
29   	 */
30   	
31   	/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32   	
33   	#include "scip/bandit.h"
34   	#include "scip/bandit_ucb.h"
35   	#include "scip/pub_bandit.h"
36   	#include "scip/pub_message.h"
37   	#include "scip/pub_misc.h"
38   	#include "scip/pub_misc_sort.h"
39   	#include "scip/scip_bandit.h"
40   	#include "scip/scip_mem.h"
41   	#include "scip/scip_randnumgen.h"
42   	
43   	
44   	#define BANDIT_NAME "ucb"
45   	#define NUMEPS 1e-6
46   	
47   	/*
48   	 * Data structures
49   	 */
50   	
51   	/** implementation specific data of UCB bandit algorithm */
52   	struct SCIP_BanditData
53   	{
54   	   int                   nselections;        /**< counter for the number of selections */
55   	   int*                  counter;            /**< array of counters how often every action has been chosen */
56   	   int*                  startperm;          /**< indices for starting permutation */
57   	   SCIP_Real*            meanscores;         /**< array of average scores for the actions */
58   	   SCIP_Real             alpha;              /**< parameter to increase confidence width */
59   	};
60   	
61   	
62   	/*
63   	 * Local methods
64   	 */
65   	
66   	/** data reset method */
67   	static
68   	SCIP_RETCODE dataReset(
69   	   BMS_BUFMEM*           bufmem,             /**< buffer memory */
70   	   SCIP_BANDIT*          ucb,                /**< ucb bandit algorithm */
71   	   SCIP_BANDITDATA*      banditdata,         /**< UCB bandit data structure */
72   	   SCIP_Real*            priorities,         /**< priorities for start permutation, or NULL */
73   	   int                   nactions            /**< number of actions */
74   	   )
75   	{
76   	   int i;
77   	   SCIP_RANDNUMGEN* rng;
78   	
79   	   assert(bufmem != NULL);
80   	   assert(ucb != NULL);
81   	   assert(nactions > 0);
82   	
83   	   /* clear counters and scores */
84   	   BMSclearMemoryArray(banditdata->counter, nactions);
85   	   BMSclearMemoryArray(banditdata->meanscores, nactions);
86   	   banditdata->nselections = 0;
87   	
88   	   rng = SCIPbanditGetRandnumgen(ucb);
89   	   assert(rng != NULL);
90   	
91   	   /* initialize start permutation as identity */
92   	   for( i = 0; i < nactions; ++i )
93   	      banditdata->startperm[i] = i;
94   	
95   	   /* prepare the start permutation in decreasing order of priority */
96   	   if( priorities != NULL )
97   	   {
98   	      SCIP_Real* prioritycopy;
99   	
100  	      SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) );
101  	
102  	      /* randomly wiggle priorities a little bit to make them unique */
103  	      for( i = 0; i < nactions; ++i )
104  	         prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS);
105  	
106  	      SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions);
107  	
108  	      BMSfreeBufferMemoryArray(bufmem, &prioritycopy);
109  	   }
110  	   else
111  	   {
112  	      /* use a random start permutation */
113  	      SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions);
114  	   }
115  	
116  	   return SCIP_OKAY;
117  	}
118  	
119  	
120  	/*
121  	 * Callback methods of bandit algorithm
122  	 */
123  	
124  	/** callback to free bandit specific data structures */
125  	SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
126  	{  /*lint --e{715}*/
127  	   SCIP_BANDITDATA* banditdata;
128  	   int nactions;
129  	   assert(bandit != NULL);
130  	
131  	   banditdata = SCIPbanditGetData(bandit);
132  	   assert(banditdata != NULL);
133  	   nactions = SCIPbanditGetNActions(bandit);
134  	
135  	   BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions);
136  	   BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions);
137  	   BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions);
138  	   BMSfreeBlockMemory(blkmem, &banditdata);
139  	
140  	   SCIPbanditSetData(bandit, NULL);
141  	
142  	   return SCIP_OKAY;
143  	}
144  	
145  	/** selection callback for bandit selector */
146  	SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
147  	{  /*lint --e{715}*/
148  	   SCIP_BANDITDATA* banditdata;
149  	   int nactions;
150  	   int* counter;
151  	
152  	   assert(bandit != NULL);
153  	   assert(selection != NULL);
154  	
155  	   banditdata = SCIPbanditGetData(bandit);
156  	   assert(banditdata != NULL);
157  	   nactions = SCIPbanditGetNActions(bandit);
158  	
159  	   counter = banditdata->counter;
160  	   /* select the next uninitialized action from the start permutation */
161  	   if( banditdata->nselections < nactions )
162  	   {
163  	      *selection = banditdata->startperm[banditdata->nselections];
164  	      assert(counter[*selection] == 0);
165  	   }
166  	   else
167  	   {
168  	      /* select the action with the highest upper confidence bound */
169  	      SCIP_Real* meanscores;
170  	      SCIP_Real widthfactor;
171  	      SCIP_Real maxucb;
172  	      int i;
173  	      SCIP_RANDNUMGEN* rng = SCIPbanditGetRandnumgen(bandit);
174  	      meanscores = banditdata->meanscores;
175  	
176  	      assert(rng != NULL);
177  	      assert(meanscores != NULL);
178  	
179  	      /* compute the confidence width factor that is common for all actions */
180  	      /* cppcheck-suppress unpreciseMathCall */
181  	      widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections);
182  	      widthfactor = sqrt(widthfactor);
183  	      maxucb = -1.0;
184  	
185  	      /* loop over the actions and determine the maximum upper confidence bound.
186  	       * The upper confidence bound of an action is the sum of its mean score
187  	       * plus a confidence term that decreases with increasing number of observations of
188  	       * this action.
189  	       */
190  	      for( i = 0; i < nactions; ++i )
191  	      {
192  	         SCIP_Real uppercb;
193  	         SCIP_Real rootcount;
194  	         assert(counter[i] > 0);
195  	
196  	         /* compute the upper confidence bound for action i */
197  	         uppercb = meanscores[i];
198  	         rootcount = sqrt((SCIP_Real)counter[i]);
199  	         uppercb += widthfactor / rootcount;
200  	         assert(uppercb > 0);
201  	
202  	         /* update maximum, breaking ties uniformly at random */
203  	         if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) )
204  	         {
205  	            maxucb = uppercb;
206  	            *selection = i;
207  	         }
208  	      }
209  	   }
210  	
211  	   assert(*selection >= 0);
212  	   assert(*selection < nactions);
213  	
214  	   return SCIP_OKAY;
215  	}
216  	
217  	/** update callback for bandit algorithm */
218  	SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
219  	{  /*lint --e{715}*/
220  	   SCIP_BANDITDATA* banditdata;
221  	   SCIP_Real delta;
222  	
223  	   assert(bandit != NULL);
224  	
225  	   banditdata = SCIPbanditGetData(bandit);
226  	   assert(banditdata != NULL);
227  	   assert(selection >= 0);
228  	   assert(selection < SCIPbanditGetNActions(bandit));
229  	
230  	   /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */
231  	   delta = score - banditdata->meanscores[selection];
232  	   ++banditdata->counter[selection];
233  	   banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection];
234  	
235  	   banditdata->nselections++;
236  	
237  	   return SCIP_OKAY;
238  	}
239  	
240  	/** reset callback for bandit algorithm */
241  	SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
242  	{  /*lint --e{715}*/
243  	   SCIP_BANDITDATA* banditdata;
244  	   int nactions;
245  	
246  	   assert(bufmem != NULL);
247  	   assert(bandit != NULL);
248  	
249  	   banditdata = SCIPbanditGetData(bandit);
250  	   assert(banditdata != NULL);
251  	   nactions = SCIPbanditGetNActions(bandit);
252  	
253  	   /* call the data reset for the given priorities */
254  	   SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) );
255  	
256  	   return SCIP_OKAY;
257  	}
258  	
259  	/*
260  	 * bandit algorithm specific interface methods
261  	 */
262  	
263  	/** returns the upper confidence bound of a selected action */
264  	SCIP_Real SCIPgetConfidenceBoundUcb(
265  	   SCIP_BANDIT*          ucb,                /**< UCB bandit algorithm */
266  	   int                   action              /**< index of the queried action */
267  	   )
268  	{
269  	   SCIP_Real uppercb;
270  	   SCIP_BANDITDATA* banditdata;
271  	   int nactions;
272  	
273  	   assert(ucb != NULL);
274  	   banditdata = SCIPbanditGetData(ucb);
275  	   nactions = SCIPbanditGetNActions(ucb);
276  	   assert(action < nactions);
277  	
278  	   /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */
279  	   if( banditdata->nselections < nactions )
280  	      return 1.0;
281  	
282  	   /* the bandit algorithm must have picked every action once */
283  	   assert(banditdata->counter[action] > 0);
284  	   uppercb = banditdata->meanscores[action];
285  	
286  	   /* cppcheck-suppress unpreciseMathCall */
287  	   uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]);
288  	
289  	   return uppercb;
290  	}
291  	
292  	/** return start permutation of the UCB bandit algorithm */
293  	int* SCIPgetStartPermutationUcb(
294  	   SCIP_BANDIT*          ucb                 /**< UCB bandit algorithm */
295  	   )
296  	{
297  	   SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb);
298  	
299  	   assert(banditdata != NULL);
300  	
301  	   return banditdata->startperm;
302  	}
303  	
304  	/** internal method to create and reset UCB bandit algorithm */
305  	SCIP_RETCODE SCIPbanditCreateUcb(
306  	   BMS_BLKMEM*           blkmem,             /**< block memory */
307  	   BMS_BUFMEM*           bufmem,             /**< buffer memory */
308  	   SCIP_BANDITVTABLE*    vtable,             /**< virtual function table for UCB bandit algorithm */
309  	   SCIP_BANDIT**         ucb,                /**< pointer to store bandit algorithm */
310  	   SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
311  	   SCIP_Real             alpha,              /**< parameter to increase confidence width */
312  	   int                   nactions,           /**< the positive number of actions for this bandit algorithm */
313  	   unsigned int          initseed            /**< initial random seed */
314  	   )
315  	{
316  	   SCIP_BANDITDATA* banditdata;
317  	
318  	   if( alpha < 0.0 )
319  	   {
320  	      SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha);
321  	      return SCIP_INVALIDDATA;
322  	   }
323  	
324  	   SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
325  	   assert(banditdata != NULL);
326  	
327  	   SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) );
328  	   SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) );
329  	   SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) );
330  	
331  	   banditdata->alpha = alpha;
332  	
333  	   SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
334  	
335  	   return SCIP_OKAY;
336  	}
337  	
338  	/** create and reset UCB bandit algorithm */
339  	SCIP_RETCODE SCIPcreateBanditUcb(
340  	   SCIP*                 scip,               /**< SCIP data structure */
341  	   SCIP_BANDIT**         ucb,                /**< pointer to store bandit algorithm */
342  	   SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
343  	   SCIP_Real             alpha,              /**< parameter to increase confidence width */
344  	   int                   nactions,           /**< the positive number of actions for this bandit algorithm */
345  	   unsigned int          initseed            /**< initial random number seed */
346  	   )
347  	{
348  	   SCIP_BANDITVTABLE* vtable;
349  	
350  	   vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
351  	   if( vtable == NULL )
352  	   {
353  	      SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
354  	      return SCIP_INVALIDDATA;
355  	   }
356  	
357  	   SCIP_CALL( SCIPbanditCreateUcb(SCIPblkmem(scip), SCIPbuffer(scip), vtable, ucb,
358  	         priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
359  	
360  	   return SCIP_OKAY;
361  	}
362  	
363  	/** include virtual function table for UCB bandit algorithms */
364  	SCIP_RETCODE SCIPincludeBanditvtableUcb(
365  	   SCIP*                 scip                /**< SCIP data structure */
366  	   )
367  	{
368  	   SCIP_BANDITVTABLE* vtable;
369  	
370  	   SCIP_CALL( SCIPincludeBanditvtable(scip, &vtable, BANDIT_NAME,
371  	         SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) );
372  	   assert(vtable != NULL);
373  	
374  	   return SCIP_OKAY;
375  	}
376