forked from huawei-noah/HEBO
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtask_base.py
74 lines (55 loc) · 2.18 KB
/
task_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify it under
# the terms of the MIT license.
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.
from abc import ABC, abstractmethod
from typing import Optional, List, Callable, Dict, Any
import numpy as np
import pandas as pd
import torch
from mcbo.search_space import SearchSpace
class TaskBase(ABC):
    """Abstract base class for optimization (** MINIMISATION **) tasks.

    Subclasses must implement ``name``, ``evaluate`` and ``search_space_params``.
    Calling an instance (``task(x)``) evaluates the black-box on a copy of ``x`` and
    increments the black-box evaluation counter by the number of rows of ``x``.
    """

    def __init__(self, **kwargs) -> None:
        # Arbitrary task configuration, stored verbatim.
        self.kwargs = kwargs
        # Running count of black-box evaluations since construction or the last restart().
        self._n_bb_evals = 0

    @property
    @abstractmethod
    def name(self) -> str:
        """
        Returns:
            A string corresponding to the name of the task
        """
        return 'Task Name'

    @abstractmethod
    def evaluate(self, x: pd.DataFrame) -> np.ndarray:
        """
        Function to compute the problem-specific black-box function.

        Args:
            x: dataframe containing the points at which the black-box should be evaluated.
                Shape: (batch_size, num_dims), where num_dims is the dimensionality of the problem
                and batch_size is the number of points in the batch.

        Returns:
            2D numpy array containing evaluated black-box values at the input x. Shape: (batch_size, 1).
        """
        pass

    @property
    def num_func_evals(self) -> int:
        """Number of black-box evaluations counted since construction or the last ``restart()``."""
        return self._n_bb_evals

    def restart(self) -> None:
        """Reset the black-box evaluation counter to zero."""
        self._n_bb_evals = 0

    @property
    def input_constraints(self) -> Optional[List[Callable[[Dict], bool]]]:
        """
        Optional list of input constraints, each a callable taking a dict and returning a bool.

        The base implementation returns ``None`` (no constraints); subclasses may override it.
        NOTE(review): presumably ``True`` marks a feasible point — confirm against the consumer.
        """
        return None

    @abstractmethod
    def search_space_params(self) -> List[Dict[str, Any]]:
        """
        Returns:
            The list of parameter-specification dicts from which ``get_search_space`` builds
            the task's ``SearchSpace`` (passed as its ``params`` argument).
        """
        pass

    def get_search_space(self, dtype: torch.dtype = torch.float64) -> 'SearchSpace':
        """
        Build the search space of this task from ``search_space_params()``.

        Args:
            dtype: torch dtype forwarded to ``SearchSpace``.

        Returns:
            A new ``SearchSpace`` instance, constructed on every call.
        """
        # String (forward-reference) return annotation: the project type is only needed
        # when this method runs, not when the class is defined.
        return SearchSpace(params=self.search_space_params(), dtype=dtype)

    def increment_n_evals(self, n: int) -> None:
        """Add ``n`` to the black-box evaluation counter."""
        self._n_bb_evals += n

    def __call__(self, x: pd.DataFrame) -> np.ndarray:
        """
        Evaluate the black-box at the points in ``x`` and count the evaluations.

        The counter is incremented by ``len(x)`` *before* ``evaluate`` runs, so it is
        incremented even if ``evaluate`` raises. ``evaluate`` receives a copy of ``x``,
        so in-place modifications it makes do not leak back to the caller's dataframe.
        """
        self.increment_n_evals(n=len(x))
        return self.evaluate(x.copy())