You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Evaluates the GRU4Rec network wrt. recommendation accuracy measured by recall@N and MRR@N.
14
+
15
+
Parameters
16
+
--------
17
+
model : A trained GRU4Rec model.
18
+
train_data : It contains the transactions of the train set. In evaluation phrase, this is used to build item-to-id map.
19
+
test_data : It contains the transactions of the test set. It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps).
20
+
cut-off : int
21
+
Cut-off value (i.e. the length of the recommendation list; N for recall@N and MRR@N). Defauld value is 20.
22
+
batch_size : int
23
+
Number of events bundled into a batch during evaluation. Speeds up evaluation. If it is set high, the memory consumption increases. Default value is 100.
24
+
session_key : string
25
+
Header of the session ID column in the input file (default: 'SessionId')
26
+
item_key : string
27
+
Header of the item ID column in the input file (default: 'ItemId')
28
+
time_key : string
29
+
Header of the timestamp column in the input file (default: 'Time')
Gives predicton scores for a selected set of items. Can be used in batch mode to predict for multiple independent events (i.e. events of different sessions) at once and thus speed up evaluation.
254
+
255
+
If the session ID at a given coordinate of the session_ids parameter remains the same during subsequent calls of the function, the corresponding hidden state of the network will be kept intact (i.e. that's how one can predict an item to a session).
256
+
If it changes, the hidden state of the network is reset to zeros.
257
+
258
+
Parameters
259
+
--------
260
+
session_ids : 1D array
261
+
Contains the session IDs of the events of the batch. Its length must equal to the prediction batch size (batch param).
262
+
input_item_ids : 1D array
263
+
Contains the item IDs of the events of the batch. Every item ID must be must be in the training data of the network. Its length must equal to the prediction batch size (batch param).
264
+
batch : int
265
+
Prediction batch size.
266
+
267
+
Returns
268
+
--------
269
+
out : pandas.DataFrame
270
+
Prediction scores for selected items for every event of the batch.
271
+
Columns: events of the batch; rows: items. Rows are indexed by the item IDs.
272
+
273
+
'''
274
+
ifbatch!=self.batch_size:
275
+
raiseException('Predict batch size({}) must match train batch size({})'.format(batch, self.batch_size))
0 commit comments