Actual source code: cupmthrustutility.hpp
1: #ifndef PETSC_CUPM_THRUST_UTILITY_HPP
2: #define PETSC_CUPM_THRUST_UTILITY_HPP
4: #include <petsc/private/deviceimpl.h>
5: #include <petsc/private/cupminterface.hpp>
7: #if defined(__cplusplus)
8: #include <thrust/device_ptr.h>
9: #include <thrust/transform.h>
11: namespace Petsc
12: {
14: namespace device
15: {
17: namespace cupm
18: {
20: namespace impl
21: {
23: #if PetscDefined(USING_NVCC)
24: #if !defined(THRUST_VERSION)
25: #error "THRUST_VERSION not defined!"
26: #endif
27: #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600)
28: #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__)
29: #else
30: #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__)
31: #endif
32: #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync
33: #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__)
34: #else
35: #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__)
36: #endif
38: namespace detail
39: {
41: struct PetscLogGpuTimer {
42: PetscLogGpuTimer() noexcept { PETSC_COMM_SELF, PetscLogGpuTimeBegin(); }
43: ~PetscLogGpuTimer() noexcept { PETSC_COMM_SELF, PetscLogGpuTimeEnd(); }
44: };
46: struct private_tag { };
48: } // namespace detail
// THRUST_CALL(func, stream, args...): forward everything to thrust_call_par_on() from inside an
// immediately-invoked lambda (capturing by reference). A detail::PetscLogGpuTimer is alive for
// exactly the duration of that lambda, and the lambda yields whatever the thrust call returns.
#define THRUST_CALL(...) \
  [&]() { \
    const ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer petsc_thrust_gpu_timer_{}; \
    return thrust_call_par_on(__VA_ARGS__); \
  }()
// PetscCallThrust(expr): evaluate expr inside a try block. If it throws thrust::system_error,
// a PETSc error (PETSC_ERR_LIB on PETSC_COMM_SELF) carrying the exception's what() text is
// raised via SETERRQ. Exceptions of any other type are not intercepted here.
#define PetscCallThrust(...) \
  do { \
    try { \
      __VA_ARGS__; \
    } catch (const thrust::system_error &thrust_err_) { \
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", thrust_err_.what()); \
    } \
  } while (false)
65: template <typename T, typename BinaryOperator>
66: struct shift_operator {
67: const T *const s;
68: const BinaryOperator op;
70: PETSC_HOSTDEVICE_DECL PETSC_FORCEINLINE auto operator()(T x) const PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(op(std::move(x), *s))
71: };
73: template <typename T, typename BinaryOperator>
74: static inline auto make_shift_operator(T *s, BinaryOperator &&op) PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(shift_operator<T, BinaryOperator>{s, std::forward<BinaryOperator>(op)});
78: // actual implementation that calls thrust, 2 argument version
79: template <DeviceType DT, typename FunctorType, typename T>
80: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr))
81: {
82: const auto xptr = thrust::device_pointer_cast(xinout);
83: const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr;
86: THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor));
87: return 0;
88: }
90: // actual implementation that calls thrust, 3 argument version
91: template <DeviceType DT, typename FunctorType, typename T>
92: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin))
93: {
94: const auto xptr = thrust::device_pointer_cast(xin);
99: THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor));
100: return 0;
101: }
103: // one last intermediate function to check n, and log flops for everything
104: template <DeviceType DT, typename F, typename... T>
105: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest))
106: {
107: PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n);
108: if (PetscLikely(n)) {
109: ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...);
110: PetscLogGpuFlops(n);
111: }
112: return 0;
113: }
115: // serves as setup to the real implementation above
116: template <DeviceType T, typename F, typename... Args>
117: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest))
118: {
119: typename Interface<T>::cupmStream_t stream;
121: static_assert(sizeof...(Args) <= 3, "");
123: PetscDeviceContextGetStreamHandle_Internal(dctx, &stream);
124: ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...);
125: return 0;
126: }
128: #define PetscCallCUPM_(...) \
129: do { \
130: using interface = Interface<DT>; \
131: using cupmError_t = typename interface::cupmError_t; \
132: const auto cupmName = []() { return interface::cupmName(); }; \
133: const auto cupmGetErrorName = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \
134: const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \
135: const auto cupmSuccess = interface::cupmSuccess; \
136: __VA_ARGS__; \
137: } while (0)
139: template <DeviceType DT, typename T>
140: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val))
141: {
143: if (n) {
144: const auto size = n * sizeof(T);
147: if (*val == T{0}) {
148: Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream);
149: } else {
150: const auto xptr = thrust::device_pointer_cast(ptr);
152: THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val);
153: if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) {
154: PetscLogCpuToGpuScalar(size);
155: } else {
156: PetscLogCpuToGpu(size);
157: }
158: }
159: }
160: return 0;
161: }
163: #undef PetscCallCUPM_
166: template <DeviceType DT, typename T>
167: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val))
168: {
169: typename Interface<DT>::cupmStream_t stream;
172: PetscDeviceContextGetStreamHandle_Internal(dctx, &stream);
173: ThrustSet(stream, n, ptr, val);
174: return 0;
175: }
177: } // namespace impl
179: } // namespace cupm
181: } // namespace device
183: } // namespace Petsc
185: #endif // __cplusplus
187: #endif // PETSC_CUPM_THRUST_UTILITY_HPP