Actual source code: cupmthrustutility.hpp

  1: #ifndef PETSC_CUPM_THRUST_UTILITY_HPP
  2: #define PETSC_CUPM_THRUST_UTILITY_HPP

  4: #include <petsc/private/deviceimpl.h>
  5: #include <petsc/private/cupminterface.hpp>

  7: #if defined(__cplusplus)
  8:   #include <thrust/device_ptr.h>
  9:   #include <thrust/transform.h>

 11: namespace Petsc
 12: {

 14: namespace device
 15: {

 17: namespace cupm
 18: {

 20: namespace impl
 21: {

 23:   #if PetscDefined(USING_NVCC)
 24:     #if !defined(THRUST_VERSION)
 25:       #error "THRUST_VERSION not defined!"
 26:     #endif
 27:     #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600)
 28:       #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__)
 29:     #else
 30:       #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__)
 31:     #endif
 32:   #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync
 33:     #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__)
 34:   #else
 35:     #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__)
 36:   #endif

 38: namespace detail
 39: {

 41: struct PetscLogGpuTimer {
 42:   PetscLogGpuTimer() noexcept { PETSC_COMM_SELF, PetscLogGpuTimeBegin(); }
 43:   ~PetscLogGpuTimer() noexcept { PETSC_COMM_SELF, PetscLogGpuTimeEnd(); }
 44: };

 46: struct private_tag { };

 48: } // namespace detail

 50:   #define THRUST_CALL(...) \
 51:     [&] { \
 52:       const auto timer = ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer{}; \
 53:       return thrust_call_par_on(__VA_ARGS__); \
 54:     }()

 56:   #define PetscCallThrust(...) \
 57:     do { \
 58:       try { \
 59:         __VA_ARGS__; \
 60:       } catch (const thrust::system_error &ex) { \
 61:         SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \
 62:       } \
 63:     } while (0)

 65: template <typename T, typename BinaryOperator>
 66: struct shift_operator {
 67:   const T *const       s;
 68:   const BinaryOperator op;

 70:   PETSC_HOSTDEVICE_DECL PETSC_FORCEINLINE auto operator()(T x) const PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(op(std::move(x), *s))
 71: };

 73: template <typename T, typename BinaryOperator>
 74: static inline auto make_shift_operator(T *s, BinaryOperator &&op) PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(shift_operator<T, BinaryOperator>{s, std::forward<BinaryOperator>(op)});


 78: // actual implementation that calls thrust, 2 argument version
 79: template <DeviceType DT, typename FunctorType, typename T>
 80: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr))
 81: {
 82:   const auto xptr   = thrust::device_pointer_cast(xinout);
 83:   const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr;

 86:   THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor));
 87:   return 0;
 88: }

 90: // actual implementation that calls thrust, 3 argument version
 91: template <DeviceType DT, typename FunctorType, typename T>
 92: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin))
 93: {
 94:   const auto xptr = thrust::device_pointer_cast(xin);

 99:   THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor));
100:   return 0;
101: }

103: // one last intermediate function to check n, and log flops for everything
104: template <DeviceType DT, typename F, typename... T>
105: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest))
106: {
107:   PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n);
108:   if (PetscLikely(n)) {
109:     ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...);
110:     PetscLogGpuFlops(n);
111:   }
112:   return 0;
113: }

115: // serves as setup to the real implementation above
116: template <DeviceType T, typename F, typename... Args>
117: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest))
118: {
119:   typename Interface<T>::cupmStream_t stream;

121:   static_assert(sizeof...(Args) <= 3, "");
123:   PetscDeviceContextGetStreamHandle_Internal(dctx, &stream);
124:   ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...);
125:   return 0;
126: }

128:   #define PetscCallCUPM_(...) \
129:     do { \
130:       using interface               = Interface<DT>; \
131:       using cupmError_t             = typename interface::cupmError_t; \
132:       const auto cupmName           = []() { return interface::cupmName(); }; \
133:       const auto cupmGetErrorName   = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \
134:       const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \
135:       const auto cupmSuccess        = interface::cupmSuccess; \
136:       __VA_ARGS__; \
137:     } while (0)

139: template <DeviceType DT, typename T>
140: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val))
141: {
143:   if (n) {
144:     const auto size = n * sizeof(T);

147:     if (*val == T{0}) {
148:       Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream);
149:     } else {
150:       const auto xptr = thrust::device_pointer_cast(ptr);

152:       THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val);
153:       if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) {
154:         PetscLogCpuToGpuScalar(size);
155:       } else {
156:         PetscLogCpuToGpu(size);
157:       }
158:     }
159:   }
160:   return 0;
161: }

163:   #undef PetscCallCUPM_

166: template <DeviceType DT, typename T>
167: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val))
168: {
169:   typename Interface<DT>::cupmStream_t stream;

172:   PetscDeviceContextGetStreamHandle_Internal(dctx, &stream);
173:   ThrustSet(stream, n, ptr, val);
174:   return 0;
175: }

177: } // namespace impl

179: } // namespace cupm

181: } // namespace device

183: } // namespace Petsc

185: #endif // __cplusplus

187: #endif // PETSC_CUPM_THRUST_UTILITY_HPP