mxnet
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
mxrtc.h
Go to the documentation of this file.
1 
7 #ifndef MXNET_MXRTC_H_
8 #define MXNET_MXRTC_H_
9 #include "./base.h"
10 #if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
11 #include <nvrtc.h>
12 #include <cuda.h>
13 
14 #include <vector>
15 #include <string>
16 #include <memory>
17 #include <utility>
18 #include <unordered_map>
19 #include "./ndarray.h"
20 
21 namespace mxnet {
22 
/*!
 * \brief Handle for a CUDA kernel whose source is compiled at runtime.
 *
 * Only declared when both MXNET_USE_CUDA and MXNET_USE_NVRTC are enabled
 * (see the enclosing #if). The header pulls in <nvrtc.h> and <cuda.h>, and the
 * class stores PTX text (ptx_) plus int-keyed CUmodule / CUfunction handles.
 * So it presumably compiles `kernel` with NVRTC and launches it through the
 * CUDA driver API. TODO: confirm against the implementation file.
 *
 * NOTE(review): the class holds a raw `char* ptx_` and CUDA driver handles but
 * declares no destructor and no copy/move operations. Implicit copies would
 * therefore share the same pointer and handles. Consider making the class
 * non-copyable once the ownership of these resources is confirmed.
 */
26 class MXRtc {
27  public:
/*!
 * \brief Construct a runtime-compiled kernel.
 * \param name kernel name.
 * \param input (string, NDArray) pairs describing the kernel inputs; the string
 *        is presumably the parameter name used inside `kernel`. Confirm.
 * \param output (string, NDArray) pairs describing the kernel outputs (same
 *        assumption as for `input`).
 * \param kernel kernel body source. It is presumably turned into full source by
 *        decorate() and then passed to compile(). TODO: confirm in the definition.
 */
38  MXRtc(const std::string& name,
39  std::vector<std::pair<std::string, NDArray> > const& input,
40  std::vector<std::pair<std::string, NDArray> > const& output,
41  const std::string& kernel);
/*!
 * \brief Push a launch of the compiled kernel over the given arrays.
 *
 * Presumably this schedules the launch rather than running it synchronously;
 * the scheduling mechanism is not visible in this header. Confirm.
 * \param input input arrays; assumed to match the count/order given at
 *        construction (num_input_). TODO: confirm.
 * \param output output arrays; same assumption (num_output_).
 * \param grid_dim_X launch grid dimension along x.
 * \param grid_dim_Y launch grid dimension along y.
 * \param grid_dim_Z launch grid dimension along z.
 * \param block_dim_X thread-block dimension along x.
 * \param block_dim_Y thread-block dimension along y.
 * \param block_dim_Z thread-block dimension along z.
 */
53  void push(std::vector<NDArray> const& input,
54  std::vector<NDArray> const& output,
55  unsigned int grid_dim_X,
56  unsigned int grid_dim_Y,
57  unsigned int grid_dim_Z,
58  unsigned int block_dim_X,
59  unsigned int block_dim_Y,
60  unsigned int block_dim_Z);
61 
62  private:
// Type-name string shared by all instances. Presumably this is the element type
// spelled into the generated kernel source. Confirm in the definition.
63  static const std::string str_type;
// Shared map from a string key to a char* buffer. Presumably a cache of
// compiled PTX, so that identical kernels are compiled only once. TODO: confirm.
// NOTE(review): no lock is declared alongside this static map; verify how
// concurrent construction of MXRtc objects is handled.
64  static std::unordered_map<std::string, char*> kernel_registry;
65 
// Kernel name (presumably the `name` passed to the constructor).
66  std::string name_;
// Number of input / output arrays, presumably taken from the sizes of the
// constructor's input/output lists.
67  index_t num_input_, num_output_;
// Kernel source, presumably the result of decorate().
68  std::string code_;
// Compiled output, presumably the PTX returned by compile(). Its ownership is
// not visible here.
69  char* ptx_;
// Loaded CUDA module and resolved kernel function, each keyed by an int.
// The key is presumably the GPU device id. Confirm.
70  std::unordered_map<int, CUmodule> module_;
71  std::unordered_map<int, CUfunction> func_;
72 
/*!
 * \brief Produce source text from a kernel body and its argument lists.
 *
 * Presumably this wraps `kernel` in a function signature built from `name`,
 * `input` and `output`. TODO: confirm in the definition.
 * NOTE(review): `kernel` is taken by const value, unlike the constructor's
 * const reference. Switching it to a reference must be done together with the
 * out-of-line definition.
 * \return the generated source string.
 */
76  std::string decorate(const std::string& name,
77  std::vector<std::pair<std::string, NDArray> > const& input,
78  std::vector<std::pair<std::string, NDArray> > const& output,
79  const std::string kernel);
/*!
 * \brief Compile `code` (identified by `name`) and return the result as a
 *        char* buffer.
 *
 * Presumably this runs NVRTC and returns the PTX text. Who frees the returned
 * buffer is not visible in this header. TODO: confirm.
 */
83  char* compile(const std::string& name, const std::string& code);
84 };
85 
86 } // namespace mxnet
87 
88 #endif // MXNET_USE_CUDA && MXNET_USE_NVRTC
89 #endif // MXNET_MXRTC_H_
NDArray interface that handles array arithmetic.
Configuration of MXNet as well as basic data structures.
mshadow::index_t index_t
Index type, usually unsigned.
Definition: base.h:80