mxnet
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
object_pool.h
Go to the documentation of this file.
1 
4 #ifndef MXNET_COMMON_OBJECT_POOL_H_
5 #define MXNET_COMMON_OBJECT_POOL_H_
6 #include <dmlc/logging.h>
7 #include <cstdlib>
8 #include <mutex>
9 #include <utility>
10 #include <vector>
11 
12 namespace mxnet {
13 namespace common {
17 template <typename T>
18 class ObjectPool {
19  public:
23  ~ObjectPool();
28  template <typename... Args>
29  T* New(Args&&... args);
36  void Delete(T* ptr);
37 
42  static ObjectPool* Get();
43 
48  static std::shared_ptr<ObjectPool> _GetSharedRef();
49 
50  private:
54  struct LinkedList {
55 #if defined(_MSC_VER)
56  T t;
57  LinkedList* next{nullptr};
58 #else
59  union {
60  T t;
61  LinkedList* next{nullptr};
62  };
63 #endif
64  };
70  constexpr static std::size_t kPageSize = 1 << 12;
72  std::mutex m_;
76  LinkedList* head_{nullptr};
80  std::vector<void*> allocated_;
84  ObjectPool();
90  void AllocateChunk();
91  DISALLOW_COPY_AND_ASSIGN(ObjectPool);
92 }; // class ObjectPool
93 
/*!
 * \brief Helper trait class for easy allocation and deallocation.
 *
 * Both members forward to the ObjectPool<T> singleton. T is presumably the
 * type deriving from this helper (CRTP-style) — confirm at call sites.
 */
template <typename T>
struct ObjectPoolAllocatable {
  /*!
   * \brief Create new object.
   * \param args Arguments forwarded to the constructor of T.
   * \return Pointer to the newly constructed object.
   */
  template <typename... Args>
  static T* New(Args&&... args);

  /*!
   * \brief Delete an existing object.
   * \param ptr Object previously returned by New().
   */
  static void Delete(T* ptr);
};  // struct ObjectPoolAllocatable
113 
114 template <typename T>
116  // TODO(hotpxl): mind destruction order
117  // for (auto i : allocated_) {
118  // free(i);
119  // }
120 }
121 
122 template <typename T>
123 template <typename... Args>
124 T* ObjectPool<T>::New(Args&&... args) {
125  LinkedList* ret;
126  {
127  std::lock_guard<std::mutex> lock{m_};
128  if (head_->next == nullptr) {
129  AllocateChunk();
130  }
131  ret = head_;
132  head_ = head_->next;
133  }
134  return new (static_cast<void*>(ret)) T(std::forward<Args>(args)...);
135 }
136 
137 template <typename T>
138 void ObjectPool<T>::Delete(T* ptr) {
139  ptr->~T();
140  auto linked_list_ptr = reinterpret_cast<LinkedList*>(ptr);
141  {
142  std::lock_guard<std::mutex> lock{m_};
143  linked_list_ptr->next = head_;
144  head_ = linked_list_ptr;
145  }
146 }
147 
148 template <typename T>
150  return _GetSharedRef().get();
151 }
152 
153 template <typename T>
154 std::shared_ptr<ObjectPool<T> > ObjectPool<T>::_GetSharedRef() {
155  static std::shared_ptr<ObjectPool<T> > inst_ptr(new ObjectPool<T>());
156  return inst_ptr;
157 }
158 
159 template <typename T>
161  AllocateChunk();
162 }
163 
164 template <typename T>
165 void ObjectPool<T>::AllocateChunk() {
166  static_assert(sizeof(LinkedList) <= kPageSize, "Object too big.");
167  static_assert(sizeof(LinkedList) % alignof(LinkedList) == 0, "ObjectPooll Invariant");
168  static_assert(alignof(LinkedList) % alignof(T) == 0, "ObjectPooll Invariant");
169  static_assert(kPageSize % alignof(LinkedList) == 0, "ObjectPooll Invariant");
170  void* new_chunk_ptr;
171 #ifdef _MSC_VER
172  new_chunk_ptr = _aligned_malloc(kPageSize, kPageSize);
173  CHECK_NE(new_chunk_ptr, NULL) << "Allocation failed";
174 #else
175  int ret = posix_memalign(&new_chunk_ptr, kPageSize, kPageSize);
176  CHECK_EQ(ret, 0) << "Allocation failed";
177 #endif
178  allocated_.emplace_back(new_chunk_ptr);
179  auto new_chunk = static_cast<LinkedList*>(new_chunk_ptr);
180  auto size = kPageSize / sizeof(LinkedList);
181  for (std::size_t i = 0; i < size - 1; ++i) {
182  new_chunk[i].next = &new_chunk[i + 1];
183  }
184  new_chunk[size - 1].next = head_;
185  head_ = new_chunk;
186 }
187 
188 template <typename T>
189 template <typename... Args>
190 T* ObjectPoolAllocatable<T>::New(Args&&... args) {
191  return ObjectPool<T>::Get()->New(std::forward<Args>(args)...);
192 }
193 
194 template <typename T>
196  ObjectPool<T>::Get()->Delete(ptr);
197 }
198 
199 } // namespace common
200 } // namespace mxnet
201 #endif // MXNET_COMMON_OBJECT_POOL_H_
static void Delete(T *ptr)
Delete an existing object.
Definition: object_pool.h:195
static T * New(Args &&...args)
Create new object.
Definition: object_pool.h:190
T * New(Args &&...args)
Create new object.
Definition: object_pool.h:124
static ObjectPool * Get()
Get singleton instance of pool.
Definition: object_pool.h:149
static std::shared_ptr< ObjectPool > _GetSharedRef()
Get a shared ptr of the singleton instance of pool.
Definition: object_pool.h:154
void Delete(T *ptr)
Delete an existing object.
Definition: object_pool.h:138
Helper trait class for easy allocation and deallocation.
Definition: object_pool.h:98
~ObjectPool()
Destructor.
Definition: object_pool.h:115
Object pool for fast allocation and deallocation.
Definition: object_pool.h:18