8 #ifndef XGBOOST_COMMON_SURVIVAL_UTIL_H_ 9 #define XGBOOST_COMMON_SURVIVAL_UTIL_H_ 47 constexpr
double kEps = 1e-12;
53 inline double Clip(
double x,
double x_min,
double x_max) {
63 template<
typename Distribution>
67 template<
typename Distribution>
80 DMLC_DECLARE_FIELD(aft_loss_distribution)
85 .describe(
"Choice of distribution for the noise term in " 86 "Accelerated Failure Time model");
87 DMLC_DECLARE_FIELD(aft_loss_distribution_scale)
89 .describe(
"Scaling factor used to scale the distribution in " 90 "Accelerated Failure Time model");
95 template<
typename Distribution>
98 double Loss(
double y_lower,
double y_upper,
double y_pred,
double sigma) {
99 const double log_y_lower = log(y_lower);
100 const double log_y_upper = log(y_upper);
104 if (y_lower == y_upper) {
105 const double z = (log_y_lower - y_pred) / sigma;
106 const double pdf = Distribution::PDF(z);
108 cost = -log(fmax(pdf / (sigma * y_lower),
aft::kEps));
110 double z_u, z_l, cdf_u, cdf_l;
111 if (isinf(y_upper)) {
114 z_u = (log_y_upper - y_pred) / sigma;
115 cdf_u = Distribution::CDF(z_u);
117 if (y_lower <= 0.0) {
120 z_l = (log_y_lower - y_pred) / sigma;
121 cdf_l = Distribution::CDF(z_l);
124 cost = -log(fmax(cdf_u - cdf_l,
aft::kEps));
131 double Gradient(
double y_lower,
double y_upper,
double y_pred,
double sigma) {
132 const double log_y_lower = log(y_lower);
133 const double log_y_upper = log(y_upper);
134 double numerator, denominator, gradient;
138 if (y_lower == y_upper) {
139 const double z = (log_y_lower - y_pred) / sigma;
140 const double pdf = Distribution::PDF(z);
141 const double grad_pdf = Distribution::GradPDF(z);
143 numerator = grad_pdf;
144 denominator = sigma * pdf;
147 double z_u = 0.0, z_l = 0.0, pdf_u, pdf_l, cdf_u, cdf_l;
149 if (isinf(y_upper)) {
154 z_u = (log_y_upper - y_pred) / sigma;
155 pdf_u = Distribution::PDF(z_u);
156 cdf_u = Distribution::CDF(z_u);
158 if (y_lower <= 0.0) {
163 z_l = (log_y_lower - y_pred) / sigma;
164 pdf_l = Distribution::PDF(z_l);
165 cdf_l = Distribution::CDF(z_l);
167 z_sign = (z_u > 0 || z_l > 0);
168 numerator = pdf_u - pdf_l;
169 denominator = sigma * (cdf_u - cdf_l);
171 gradient = numerator / denominator;
172 if (denominator <
aft::kEps && (isnan(gradient) || isinf(gradient))) {
173 gradient = aft::GetLimitGradAtInfPred<Distribution>(censor_type, z_sign, sigma);
180 double Hessian(
double y_lower,
double y_upper,
double y_pred,
double sigma) {
181 const double log_y_lower = log(y_lower);
182 const double log_y_upper = log(y_upper);
183 double numerator, denominator, hessian;
187 if (y_lower == y_upper) {
188 const double z = (log_y_lower - y_pred) / sigma;
189 const double pdf = Distribution::PDF(z);
190 const double grad_pdf = Distribution::GradPDF(z);
191 const double hess_pdf = Distribution::HessPDF(z);
193 numerator = -(pdf * hess_pdf - grad_pdf * grad_pdf);
194 denominator = sigma * sigma * pdf * pdf;
197 double z_u = 0.0, z_l = 0.0, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l;
199 if (isinf(y_upper)) {
205 z_u = (log_y_upper - y_pred) / sigma;
206 pdf_u = Distribution::PDF(z_u);
207 cdf_u = Distribution::CDF(z_u);
208 grad_pdf_u = Distribution::GradPDF(z_u);
210 if (y_lower <= 0.0) {
216 z_l = (log_y_lower - y_pred) / sigma;
217 pdf_l = Distribution::PDF(z_l);
218 cdf_l = Distribution::CDF(z_l);
219 grad_pdf_l = Distribution::GradPDF(z_l);
221 const double cdf_diff = cdf_u - cdf_l;
222 const double pdf_diff = pdf_u - pdf_l;
223 const double grad_diff = grad_pdf_u - grad_pdf_l;
224 const double sqrt_denominator = sigma * cdf_diff;
225 z_sign = (z_u > 0 || z_l > 0);
226 numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff);
227 denominator = sqrt_denominator * sqrt_denominator;
229 hessian = numerator / denominator;
230 if (denominator <
aft::kEps && (isnan(hessian) || isinf(hessian))) {
231 hessian = aft::GetLimitHessAtInfPred<Distribution>(censor_type, z_sign, sigma);
243 switch (censor_type) {
253 return std::numeric_limits<double>::quiet_NaN();
259 switch (censor_type) {
261 return 1.0 / (sigma * sigma);
263 return sign ? (1.0 / (sigma * sigma)) :
kMinHessian;
265 return sign ?
kMinHessian : (1.0 / (sigma * sigma));
267 return 1.0 / (sigma * sigma);
269 return std::numeric_limits<double>::quiet_NaN();
275 switch (censor_type) {
277 return sign ? (-1.0 / sigma) : (1.0 / sigma);
279 return sign ? (-1.0 / sigma) : 0.0;
281 return sign ? 0.0 : (1.0 / sigma);
283 return sign ? (-1.0 / sigma) : (1.0 / sigma);
285 return std::numeric_limits<double>::quiet_NaN();
291 switch (censor_type) {
298 return std::numeric_limits<double>::quiet_NaN();
304 switch (censor_type) {
310 return sign ? 0.0 : (1.0 / sigma);
314 return std::numeric_limits<double>::quiet_NaN();
320 switch (censor_type) {
329 return std::numeric_limits<double>::quiet_NaN();
337 #endif // XGBOOST_COMMON_SURVIVAL_UTIL_H_ The AFT loss function.
Definition: survival_util.h:96
static XGBOOST_DEVICE double Loss(double y_lower, double y_upper, double y_pred, double sigma)
Definition: survival_util.h:98
static XGBOOST_DEVICE double Gradient(double y_lower, double y_upper, double y_pred, double sigma)
Definition: survival_util.h:131
Definition: parameter.h:84
XGBOOST_DEVICE double GetLimitHessAtInfPred< LogisticDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:290
constexpr double kMinHessian
Definition: survival_util.h:44
constexpr double kEps
Definition: survival_util.h:47
XGBOOST_DEVICE double GetLimitHessAtInfPred< ExtremeDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:319
float aft_loss_distribution_scale
Scaling factor to be applied to the distribution.
Definition: survival_util.h:78
constexpr double kMaxGradient
Definition: survival_util.h:43
XGBOOST_DEVICE double GetLimitGradAtInfPred< LogisticDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:274
ProbabilityDistributionType
Enum encoding possible choices of probability distribution.
Definition: probability_distribution.h:31
DMLC_DECLARE_PARAMETER(AFTParam)
Definition: survival_util.h:79
Parameter structure for AFT loss and metric.
Definition: survival_util.h:74
XGBOOST_DEVICE double GetLimitGradAtInfPred(CensoringType censor_type, bool sign, double sigma)
DECLARE_FIELD_ENUM_CLASS(xgboost::common::ProbabilityDistributionType)
constexpr double kMaxHessian
Definition: survival_util.h:45
XGBOOST_DEVICE double GetLimitHessAtInfPred(CensoringType censor_type, bool sign, double sigma)
XGBOOST_DEVICE double GetLimitHessAtInfPred< NormalDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:258
XGBOOST_DEVICE double Clip(double x, double x_min, double x_max)
Definition: survival_util.h:53
constexpr double kMinGradient
Definition: survival_util.h:42
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:84
CensoringType
Definition: survival_util.h:35
namespace of xgboost
Definition: base.h:102
Implementation of a few useful probability distributions.
XGBOOST_DEVICE double GetLimitGradAtInfPred< NormalDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:242
XGBOOST_DEVICE double GetLimitGradAtInfPred< ExtremeDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:303
static XGBOOST_DEVICE double Hessian(double y_lower, double y_upper, double y_pred, double sigma)
Definition: survival_util.h:180
macro for using C++11 enum class as DMLC parameter
ProbabilityDistributionType aft_loss_distribution
Choice of probability distribution for the noise term in AFT.
Definition: survival_util.h:76