1    	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2    	/*                                                                           */
3    	/*                  This file is part of the program and library             */
4    	/*         SCIP --- Solving Constraint Integer Programs                      */
5    	/*                                                                           */
6    	/*  Copyright (c) 2002-2023 Zuse Institute Berlin (ZIB)                      */
7    	/*                                                                           */
8    	/*  Licensed under the Apache License, Version 2.0 (the "License");          */
9    	/*  you may not use this file except in compliance with the License.         */
10   	/*  You may obtain a copy of the License at                                  */
11   	/*                                                                           */
12   	/*      http://www.apache.org/licenses/LICENSE-2.0                           */
13   	/*                                                                           */
14   	/*  Unless required by applicable law or agreed to in writing, software      */
15   	/*  distributed under the License is distributed on an "AS IS" BASIS,        */
16   	/*  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17   	/*  See the License for the specific language governing permissions and      */
18   	/*  limitations under the License.                                           */
19   	/*                                                                           */
20   	/*  You should have received a copy of the Apache-2.0 license                */
21   	/*  along with SCIP; see the file LICENSE. If not visit scipopt.org.         */
22   	/*                                                                           */
23   	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24   	
25   	/**@file   bandit.c
26   	 * @ingroup OTHER_CFILES
27   	 * @brief  internal API of bandit algorithms and bandit virtual function tables
28   	 * @author Gregor Hendel
29   	 */
30   	
31   	/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32   	
33   	#include <assert.h>
34   	#include <string.h>
35   	#include "scip/bandit.h"
36   	#include "scip/pub_bandit.h"
37   	#include "scip/struct_bandit.h"
38   	#include "scip/struct_set.h"
39   	#include "scip/set.h"
40   	
41   	/** creates and resets bandit algorithm */
42   	SCIP_RETCODE SCIPbanditCreate(
43   	   SCIP_BANDIT**         bandit,             /**< pointer to bandit algorithm data structure */
44   	   SCIP_BANDITVTABLE*    banditvtable,       /**< virtual table for this bandit algorithm */
45   	   BMS_BLKMEM*           blkmem,             /**< block memory for parameter settings */
46   	   BMS_BUFMEM*           bufmem,             /**< buffer memory */
47   	   SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
48   	   int                   nactions,           /**< the positive number of actions for this bandit */
49   	   unsigned int          initseed,           /**< initial seed for random number generation */
50   	   SCIP_BANDITDATA*      banditdata          /**< algorithm specific bandit data */
51   	   )
52   	{
53   	   SCIP_BANDIT* banditptr;
54   	   assert(bandit != NULL);
55   	   assert(banditvtable != NULL);
56   	
57   	   /* the number of actions must be positive */
58   	   if( nactions <= 0 )
59   	   {
60   	      SCIPerrorMessage("Cannot create bandit selector with %d <= 0 actions\n", nactions);
61   	
62   	      return SCIP_INVALIDDATA;
63   	   }
64   	
65   	   SCIP_ALLOC( BMSallocBlockMemory(blkmem, bandit) );
66   	   assert(*bandit != NULL);
67   	   banditptr = *bandit;
68   	   banditptr->vtable = banditvtable;
69   	   banditptr->data = banditdata;
70   	   banditptr->nactions = nactions;
71   	
72   	   SCIP_CALL( SCIPrandomCreate(&banditptr->rng, blkmem, initseed) );
73   	
74   	   SCIP_CALL( SCIPbanditReset(bufmem, banditptr, priorities, initseed) );
75   	
76   	   return SCIP_OKAY;
77   	}
78   	
79   	/** calls destructor and frees memory of bandit algorithm */
80   	SCIP_RETCODE SCIPbanditFree(
81   	   BMS_BLKMEM*           blkmem,             /**< block memory */
82   	   SCIP_BANDIT**         bandit              /**< pointer to bandit algorithm data structure */
83   	   )
84   	{
85   	   SCIP_BANDIT* banditptr;
86   	   SCIP_BANDITVTABLE* vtable;
87   	   assert(bandit != NULL);
88   	   assert(*bandit != NULL);
89   	
90   	   banditptr = *bandit;
91   	   vtable = banditptr->vtable;
92   	   assert(vtable != NULL);
93   	
94   	   /* call bandit specific data destructor */
95   	   if( vtable->banditfree != NULL )
96   	   {
97   	      SCIP_CALL( vtable->banditfree(blkmem, banditptr) );
98   	   }
99   	
100  	   /* free random number generator */
101  	   SCIPrandomFree(&banditptr->rng, blkmem);
102  	
103  	   BMSfreeBlockMemory(blkmem, bandit);
104  	
105  	   return SCIP_OKAY;
106  	}
107  	
108  	/** reset the bandit algorithm */
109  	SCIP_RETCODE SCIPbanditReset(
110  	   BMS_BUFMEM*           bufmem,             /**< buffer memory */
111  	   SCIP_BANDIT*          bandit,             /**< pointer to bandit algorithm data structure */
112  	   SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
113  	   unsigned int          seed                /**< initial seed for random number generation */
114  	   )
115  	{
116  	   SCIP_BANDITVTABLE* vtable;
117  	
118  	   assert(bandit != NULL);
119  	   assert(bufmem != NULL);
120  	
121  	   vtable = bandit->vtable;
122  	   assert(vtable != NULL);
123  	   assert(vtable->banditreset != NULL);
124  	
125  	   /* test if the priorities are nonnegative */
126  	   if( priorities != NULL )
127  	   {
128  	      int i;
129  	
130  	      assert(SCIPbanditGetNActions(bandit) > 0);
131  	
132  	      for( i = 0; i < SCIPbanditGetNActions(bandit); ++i )
133  	      {
134  	         if( priorities[i] < 0 )
135  	         {
136  	            SCIPerrorMessage("Negative priority for action %d\n", i);
137  	
138  	            return SCIP_INVALIDDATA;
139  	         }
140  	      }
141  	   }
142  	
143  	   /* reset the random seed of the bandit algorithm */
144  	   SCIPrandomSetSeed(bandit->rng, seed);
145  	
146  	   /* call the reset callback of the bandit algorithm */
147  	   SCIP_CALL( vtable->banditreset(bufmem, bandit, priorities) );
148  	
149  	   return SCIP_OKAY;
150  	}
151  	
152  	/** select the next action */
153  	SCIP_RETCODE SCIPbanditSelect(
154  	   SCIP_BANDIT*          bandit,             /**< bandit algorithm data structure */
155  	   int*                  action              /**< pointer to store the selected action */
156  	   )
157  	{
158  	   assert(bandit != NULL);
159  	   assert(action != NULL);
160  	
161  	   *action = -1;
162  	
163  	   assert(bandit->vtable->banditselect != NULL);
164  	
165  	   SCIP_CALL( bandit->vtable->banditselect(bandit, action) );
166  	
167  	   assert(*action >= 0);
168  	   assert(*action < SCIPbanditGetNActions(bandit));
169  	
170  	   return SCIP_OKAY;
171  	}
172  	
173  	/** update the score of the selected action */
174  	SCIP_RETCODE SCIPbanditUpdate(
175  	   SCIP_BANDIT*          bandit,             /**< bandit algorithm data structure */
176  	   int                   action,             /**< index of action for which the score should be updated */
177  	   SCIP_Real             score               /**< observed gain of the i'th action */
178  	   )
179  	{
180  	   assert(bandit != NULL);
181  	   assert(0 <= action && action < SCIPbanditGetNActions(bandit));
182  	   assert(bandit->vtable->banditupdate != NULL);
183  	
184  	   SCIP_CALL( bandit->vtable->banditupdate(bandit, action, score) );
185  	
186  	   return SCIP_OKAY;
187  	}
188  	
189  	/** get data of this bandit algorithm */
190  	SCIP_BANDITDATA* SCIPbanditGetData(
191  	   SCIP_BANDIT*          bandit              /**< bandit algorithm data structure */
192  	   )
193  	{
194  	   assert(bandit != NULL);
195  	
196  	   return bandit->data;
197  	}
198  	
199  	/** set the data of this bandit algorithm */
200  	void SCIPbanditSetData(
201  	   SCIP_BANDIT*          bandit,             /**< bandit algorithm data structure */
202  	   SCIP_BANDITDATA*      banditdata          /**< bandit algorihm specific data, or NULL */
203  	   )
204  	{
205  	   assert(bandit != NULL);
206  	
207  	   bandit->data = banditdata;
208  	}
209  	
210  	/** internal method to create a bandit VTable */
211  	static
212  	SCIP_RETCODE doBanditvtableCreate(
213  	   SCIP_BANDITVTABLE**   banditvtable,       /**< pointer to virtual table for bandit algorithm */
214  	   const char*           name,               /**< a name for the algorithm represented by this vtable */
215  	   SCIP_DECL_BANDITFREE  ((*banditfree)),    /**< callback to free bandit specific data structures */
216  	   SCIP_DECL_BANDITSELECT((*banditselect)),  /**< selection callback for bandit selector */
217  	   SCIP_DECL_BANDITUPDATE((*banditupdate)),  /**< update callback for bandit algorithms */
218  	   SCIP_DECL_BANDITRESET ((*banditreset))    /**< update callback for bandit algorithms */
219  	   )
220  	{
221  	   SCIP_BANDITVTABLE* banditvtableptr;
222  	
223  	   assert(banditvtable != NULL);
224  	   assert(name != NULL);
225  	   assert(banditfree != NULL);
226  	   assert(banditselect != NULL);
227  	   assert(banditupdate != NULL);
228  	   assert(banditreset != NULL);
229  	
230  	   /* allocate memory for this virtual function table */
231  	   SCIP_ALLOC( BMSallocMemory(banditvtable) );
232  	   BMSclearMemory(*banditvtable);
233  	
234  	   SCIP_ALLOC( BMSduplicateMemoryArray(&(*banditvtable)->name, name, strlen(name)+1) );
235  	   banditvtableptr = *banditvtable;
236  	   banditvtableptr->banditfree = banditfree;
237  	   banditvtableptr->banditselect = banditselect;
238  	   banditvtableptr->banditupdate = banditupdate;
239  	   banditvtableptr->banditreset = banditreset;
240  	
241  	   return SCIP_OKAY;
242  	}
243  	
244  	/** create a bandit VTable for bandit algorithm callback functions */
245  	SCIP_RETCODE SCIPbanditvtableCreate(
246  	   SCIP_BANDITVTABLE**   banditvtable,       /**< pointer to virtual table for bandit algorithm */
247  	   const char*           name,               /**< a name for the algorithm represented by this vtable */
248  	   SCIP_DECL_BANDITFREE  ((*banditfree)),    /**< callback to free bandit specific data structures */
249  	   SCIP_DECL_BANDITSELECT((*banditselect)),  /**< selection callback for bandit selector */
250  	   SCIP_DECL_BANDITUPDATE((*banditupdate)),  /**< update callback for bandit algorithms */
251  	   SCIP_DECL_BANDITRESET ((*banditreset))    /**< update callback for bandit algorithms */
252  	   )
253  	{
254  	   assert(banditvtable != NULL);
255  	   assert(name != NULL);
256  	   assert(banditfree != NULL);
257  	   assert(banditselect != NULL);
258  	   assert(banditupdate != NULL);
259  	   assert(banditreset != NULL);
260  	
261  	   SCIP_CALL_FINALLY( doBanditvtableCreate(banditvtable, name, banditfree, banditselect, banditupdate, banditreset),
262  	      SCIPbanditvtableFree(banditvtable) );
263  	
264  	   return SCIP_OKAY;
265  	}
266  	
267  	
268  	/** free a bandit virtual table for bandit algorithm callback functions */
269  	void SCIPbanditvtableFree(
270  	   SCIP_BANDITVTABLE**   banditvtable        /**< pointer to virtual table for bandit algorithm */
271  	   )
272  	{
273  	   assert(banditvtable != NULL);
274  	   if( *banditvtable == NULL )
275  	      return;
276  	
277  	   BMSfreeMemoryArrayNull(&(*banditvtable)->name);
278  	   BMSfreeMemory(banditvtable);
279  	}
280  	
281  	/** return the name of this bandit virtual function table */
282  	const char* SCIPbanditvtableGetName(
283  	   SCIP_BANDITVTABLE*    banditvtable        /**< virtual table for bandit algorithm */
284  	   )
285  	{
286  	   assert(banditvtable != NULL);
287  	
288  	   return banditvtable->name;
289  	}
290  	
291  	
292  	/** return the random number generator of a bandit algorithm */
293  	SCIP_RANDNUMGEN* SCIPbanditGetRandnumgen(
294  	   SCIP_BANDIT*          bandit              /**< bandit algorithm data structure */
295  	   )
296  	{
297  	   assert(bandit != NULL);
298  	
299  	   return bandit->rng;
300  	}
301  	
302  	/** return number of actions of this bandit algorithm */
303  	int SCIPbanditGetNActions(
304  	   SCIP_BANDIT*          bandit              /**< bandit algorithm data structure */
305  	   )
306  	{
307  	   assert(bandit != NULL);
308  	
309  	   return bandit->nactions;
310  	}
311