10 #if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
18 #include <unordered_map>
38 MXRtc(
const std::string& name,
39 std::vector<std::pair<std::string, NDArray> >
const& input,
40 std::vector<std::pair<std::string, NDArray> >
const& output,
41 const std::string& kernel);
53 void push(std::vector<NDArray>
const& input,
54 std::vector<NDArray>
const& output,
55 unsigned int grid_dim_X,
56 unsigned int grid_dim_Y,
57 unsigned int grid_dim_Z,
58 unsigned int block_dim_X,
59 unsigned int block_dim_Y,
60 unsigned int block_dim_Z);
/*!
 * \brief Class-wide string constant.
 * NOTE(review): presumably the element-type name used when generating kernel
 * code. TODO confirm.
 */
static const std::string str_type;
/*!
 * \brief Class-wide map from a string key to a raw char buffer.
 * NOTE(review): presumably caches compiled kernel output so that identical
 * sources are not recompiled. Ownership of the char* values is not visible in
 * this header; confirm who frees them.
 */
static std::unordered_map<std::string, char*> kernel_registry;
67 index_t num_input_, num_output_;
70 std::unordered_map<int, CUmodule> module_;
71 std::unordered_map<int, CUfunction> func_;
76 std::string decorate(
const std::string& name,
77 std::vector<std::pair<std::string, NDArray> >
const& input,
78 std::vector<std::pair<std::string, NDArray> >
const& output,
79 const std::string kernel);
/*!
 * \brief Presumably compiles `code` and returns the compiled output.
 * \param name name associated with the code (presumably the kernel/program name).
 * \param code source code to compile.
 * \return pointer to a char buffer holding the result.
 *
 * NOTE(review): ownership of the returned buffer is not visible in this
 * header. Confirm who frees it; it may be stored in kernel_registry.
 */
char* compile(const std::string& name, const std::string& code);
88 #endif // MXNET_USE_CUDA && MXNET_USE_NVRTC
89 #endif // MXNET_MXRTC_H_
NDArray interface that handles array arithmetic.
Configuration of MXNet as well as basic data structures.
mshadow::index_t index_t
Index type; usually unsigned.
Definition: base.h:80