Actual source code: curand2.cu
#include <petsc/private/randomimpl.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/transform.h>
#include <thrust/tuple.h>
#include <exception>
#if defined(PETSC_USE_COMPLEX)
/*
  Functor that rescales uniform samples stored as an interleaved real/imaginary
  PetscReal array: the input tuple is (sample, flat index). Even indices are
  mapped to rw * sample + rl (real part of low/width), odd indices to
  iw * sample + il (imaginary part of low/width).

  Does not derive from thrust::unary_function: that base is deprecated in
  CCCL 2.x and removed in CCCL 3.0. Thrust deduces the result type from
  operator() itself; result_type is kept only for legacy adaptable-function users.
*/
struct complexscalelw {
  using result_type = PetscReal;

  PetscReal rl, rw; /* real part of low and width */
  PetscReal il, iw; /* imaginary part of low and width */

  complexscalelw(PetscScalar low, PetscScalar width) : rl(PetscRealPart(low)), rw(PetscRealPart(width)), il(PetscImaginaryPart(low)), iw(PetscImaginaryPart(width)) { }

  __host__ __device__ PetscReal operator()(thrust::tuple<PetscReal, size_t> x) const
  {
    const PetscReal sample = thrust::get<0>(x);
    /* odd flat index -> imaginary component, even -> real component */
    return thrust::get<1>(x) % 2 ? sample * iw + il : sample * rw + rl;
  }
};
#endif
/*
  Functor that affinely maps a sample x to x * w + l, where l and w are the
  lower bound and width to scale into.

  Does not derive from thrust::unary_function: that base is deprecated in
  CCCL 2.x and removed in CCCL 3.0. Thrust deduces the result type from
  operator() itself; result_type is kept only for legacy adaptable-function users.
*/
struct realscalelw {
  using result_type = PetscReal;

  PetscReal l, w; /* lower bound and width */

  realscalelw(PetscReal low, PetscReal width) : l(low), w(width) { }

  __host__ __device__ PetscReal operator()(PetscReal x) const { return x * w + l; }
};
/*
  PetscRandomCurandScale_Private - rescale n PetscReal values in device memory
  in place, using the interval stored in r (r->low, r->width).

  Input Parameters:
+ r     - the random context; nothing is done when r->iset is false
. n     - number of PetscReal entries in val
. val   - device pointer to the values to rescale
- isneg - if PETSC_TRUE, val is treated as interleaved (real, imaginary) pairs
          and each component is scaled by the matching part of low/width;
          this is an error in real-scalar builds

  Notes:
  This function has C linkage and is called from C code, so any C++ exception
  thrown by Thrust is caught here and turned into a PETSc error instead of
  unwinding through C frames (undefined behaviour).
*/
PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg)
{
  PetscFunctionBegin;
  if (!r->iset) PetscFunctionReturn(0);
  try {
    thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);
    if (isneg) { /* complex case, need to scale differently */
#if defined(PETSC_USE_COMPLEX)
      /* pair each value with its flat index so the functor can tell real from imaginary slots */
      auto zibit = thrust::make_zip_iterator(thrust::make_tuple(pval, thrust::counting_iterator<size_t>(0)));
      thrust::transform(zibit, zibit + n, pval, complexscalelw(r->low, r->width));
#else
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Negative array size %" PetscInt_FMT, (PetscInt)n);
#endif
    } else {
      PetscReal rl = PetscRealPart(r->low);
      PetscReal rw = PetscRealPart(r->width);
      thrust::transform(pval, pval + n, pval, realscalelw(rl, rw));
    }
  } catch (const std::exception &ex) {
    /* covers thrust::system_error (CUDA failures) and std::bad_alloc */
    SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what());
  }
  PetscFunctionReturn(0);
}