DOLFIN
DOLFIN C++ interface
Loading...
Searching...
No Matches
SUNDIALSNVector.h
1// Copyright (C) 2017 Chris Hadjigeorgiou and Chris Richardson
2//
3// This file is part of DOLFIN.
4//
5// DOLFIN is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Lesser General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// DOLFIN is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Lesser General Public License for more details.
14//
15// You should have received a copy of the GNU Lesser General Public License
16// along with DOLFIN. If not, see <http://www.gnu.org/licenses/>.
17
18
19#ifndef __DOLFIN_N_VECTOR_H
20#define __DOLFIN_N_VECTOR_H
21
22#ifdef HAS_SUNDIALS
23
24#include <string>
25#include <utility>
26#include <memory>
27#include <dolfin/common/types.h>
28#include <sundials/sundials_nvector.h>
29#include "DefaultFactory.h"
30#include "GenericVector.h"
31#include "Vector.h"
32
33namespace dolfin
34{
39 {
40 public:
41
45 SUNDIALSNVector(MPI_Comm comm=MPI_COMM_WORLD)
46 {
47 DefaultFactory factory;
48 vector = factory.create_vector(comm);
49 }
50
56 SUNDIALSNVector(MPI_Comm comm, std::size_t N)
57 {
58 DefaultFactory factory;
59 vector = factory.create_vector(comm);
60 vector->init(N);
61 N_V = std::unique_ptr<_generic_N_Vector>(new _generic_N_Vector);
62 N_V->ops = &ops;
63 N_V->content = (void *)(this);
64 }
65
69 SUNDIALSNVector(const SUNDIALSNVector& x) : vector(x.vec()->copy()) {}
70
74 SUNDIALSNVector(const GenericVector& x) : vector(x.copy())
75 {
76 N_V = std::unique_ptr<_generic_N_Vector>(new _generic_N_Vector);
77 N_V->ops = &ops;
78 N_V->content = (void *)(this);
79 }
80
84 SUNDIALSNVector(std::shared_ptr<GenericVector> x) : vector(x)
85 {
86 N_V = std::unique_ptr<_generic_N_Vector>(new _generic_N_Vector);
87 N_V->ops = &ops;
88 N_V->content = (void *)(this);
89 }
90 //-----------------------------------------------------------------------------
91
95 N_Vector nvector() const
96 {
97 N_V->content = (void *)(this);
98 return N_V.get();
99 }
100
104 std::shared_ptr<GenericVector> vec() const
105 {
106 return vector;
107 }
108
// Assignment copies the values of x's GenericVector into this object's
// existing GenericVector object (the shared_ptr is not reseated and the
// SUNDIALS handle N_V is left untouched).
// NOTE(review): presumably requires compatible sizes/layouts — confirm
// against the backend's GenericVector::operator=.
111 { *vector = *x.vector; return *this; }
112
113 private:
114
115 //--- Implementation of N_Vector ops
116
117 // Get ID for custom SUNDIALSNVector implementation
118 static N_Vector_ID N_VGetVectorID(N_Vector nv)
119 {
120 dolfin_debug("N_VGetVectorID");
121 return SUNDIALS_NVEC_CUSTOM;
122 }
123
124 // Sets the components of the N_Vector z to be the absolute values of the
125 // components of the N_Vector x
126 static void N_VAbs(N_Vector x, N_Vector z)
127 {
128 dolfin_debug("N_VAbs");
129 auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
130 auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
131
132 *vz = *vx;
133 vz->abs();
134 }
135
137 static void N_VConst(double c, N_Vector z)
138 {
139 dolfin_debug("N_VConst");
140 auto v = static_cast<SUNDIALSNVector *>(z->content)->vec();
141 *v = c;
142 }
143
147 static N_Vector N_VClone(N_Vector z)
148 {
149 dolfin_debug("N_VClone");
150 auto vz = static_cast<const SUNDIALSNVector *>(z->content);
151
152 SUNDIALSNVector *new_vector = new SUNDIALSNVector(*vz);
153
154 _generic_N_Vector *V = new _generic_N_Vector;
155 V->ops = z->ops;
156 V->content = (void *)(new_vector);
157
158 return V;
159 }
160
163 static N_Vector N_VCloneEmpty(N_Vector x)
164 {
165 dolfin_debug("N_VCloneEmpty");
166 dolfin_not_implemented();
167 return NULL;
168 }
169
172 static void N_VDestroy(N_Vector z)
173 {
174 dolfin_debug("N_VDestroy");
175 delete (SUNDIALSNVector*)(z->content);
176 delete z;
177 }
178
181 static void N_VSpace(N_Vector x, long int *y, long int *z)
182 {
183 dolfin_debug("N_VSpace");
184 dolfin_not_implemented();
185 }
186
188 static double* N_VGetArrayPointer(N_Vector x)
189 {
190 dolfin_debug("N_VGetArrayPointer");
191 dolfin_not_implemented();
192 return NULL;
193 }
194
196 static void N_VSetArrayPointer(double* c,N_Vector x)
197 {
198 dolfin_debug("N_VSetArrayPointer");
199 dolfin_not_implemented();
200 }
201
204 static void N_VProd(N_Vector x, N_Vector y, N_Vector z)
205 {
206 dolfin_debug("N_VProd");
207 auto vx = static_cast<const SUNDIALSNVector*>(x->content)->vec();
208 auto vy = static_cast<const SUNDIALSNVector*>(y->content)->vec();
209 auto vz = static_cast<SUNDIALSNVector*>(z->content)->vec();
210
211 // Copy x to z
212 *vz = *vx;
213 // Multiply by y
214 *vz *= *vy;
215 }
216
219 static void N_VDiv(N_Vector x, N_Vector y, N_Vector z)
220 {
221 dolfin_debug("N_VDiv");
222
223 auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
224 auto vy = static_cast<const SUNDIALSNVector *>(y->content)->vec();
225 auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
226
227 std::vector<double> xdata;
228 vx->get_local(xdata);
229 std::vector<double> ydata;
230 vy->get_local(ydata);
231 for (unsigned int i = 0; i != xdata.size(); ++i)
232 xdata[i] /= ydata[i];
233
234 vz->set_local(xdata);
235 vz->apply("insert");
236
237 }
238
240 static void N_VScale(double c, N_Vector x, N_Vector z)
241 {
242 dolfin_debug("N_VScale");
243 auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
244 auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
245
246 // z = c*x
247 *vz = *vx;
248 *vz *= c;
249 }
250
253 static void N_VInv(N_Vector x, N_Vector z)
254 {
255 dolfin_debug("N_VInv");
256 auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
257 auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
258
259 // z = 1/x
260 std::vector<double> xvals;
261 vx->get_local(xvals);
262 for (auto &val : xvals)
263 val = 1.0/val;
264 vz->set_local(xvals);
265 vz->apply("insert");
266 }
267
270 static void N_VAddConst(N_Vector x, double c, N_Vector z)
271 {
272 dolfin_debug("N_VAddConst");
273 auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
274 auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
275
276 *vz = *vx;
277 *vz += c;
278 }
279
281 static double N_VDotProd(N_Vector x, N_Vector z)
282 {
283 dolfin_debug("N_VDotProd");
284 auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
285 auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
286
287 return vx->inner(*vz);
288 }
289
291 static double N_VMaxNorm(N_Vector x)
292 {
293 dolfin_debug("N_VMaxNorm");
294 auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
295 auto vy = vx->copy();
296 vy->abs();
297 return vy->max();
298 }
299
301 static double N_VMin(N_Vector x)
302 {
303 dolfin_debug("N_VMin");
304 return (static_cast<const SUNDIALSNVector *>(x->content)->vec())->min();
305 }
306
309 static void N_VLinearSum(double a, N_Vector x, double b, N_Vector y, N_Vector z)
310 {
311 dolfin_debug("N_VLinearSum");
312 auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
313 auto vy = static_cast<const SUNDIALSNVector *>(y->content)->vec();
314 auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
315
316 std::vector<double> xdata;
317 vx->get_local(xdata);
318 std::vector<double> ydata;
319 vy->get_local(ydata);
320
321 for (unsigned int i = 0; i != xdata.size(); ++i)
322 xdata[i] = a*xdata[i] + b*ydata[i];
323
324 vz->set_local(xdata);
325 vz->apply("insert");
326 }
327
330 static double N_VWrmsNorm(N_Vector x, N_Vector z)
331 {
332 dolfin_debug("N_VWrmsNorm");
333 auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
334 auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
335
336 auto y = vx->copy();
337 *y *= *vz;
338 return y->norm("l2")/std::sqrt(y->size());
339 }
340
344 static double N_VWrmsNormMask(N_Vector x, N_Vector y, N_Vector z)
345 {
346 dolfin_debug("N_VWrmsNormMask");
347 dolfin_not_implemented();
348 return 0.0;
349 }
350
353 static double N_VWl2Norm(N_Vector x, N_Vector z )
354 {
355 dolfin_debug("N_VWl2Norm");
356 dolfin_not_implemented();
357 return 0.0;
358 }
359
361 static double N_VL1Norm(N_Vector x )
362 {
363 dolfin_debug("N_VL1Norm");
364 dolfin_not_implemented();
365 return 0.0;
366 }
367
370 static void N_VCompare(double c, N_Vector x, N_Vector z)
371 {
372 dolfin_debug("N_VCompare");
373 auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
374 auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
375 std::vector<double> xvals;
376 vx->get_local(xvals);
377 for (auto &val : xvals)
378 val = (std::abs(val) >= c) ? 1.0 : 0.0;
379 vz->set_local(xvals);
380 vz->apply("insert");
381 }
382
385 static int N_VInvTest(N_Vector x, N_Vector z)
386 {
387 dolfin_debug("N_VInvTest");
388 int no_zero_found = true;
389 auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
390 auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
391
392 std::vector<double> xvals;
393 vx->get_local(xvals);
394 for (auto &val : xvals)
395 if(val != 0)
396 val = 1.0/val;
397 else
398 no_zero_found = false;
399 vz->set_local(xvals);
400
401 vz->apply("insert");
402
403 return no_zero_found;
404 }
405
406
409 static double N_VMinQuotient(N_Vector x, N_Vector z )
410 {
411 dolfin_debug("N_VConstrMask");
412 dolfin_not_implemented();
413 return 0.0;
414 }
415
417 static int N_VConstrMask(N_Vector x, N_Vector y, N_Vector z )
418 {
419 dolfin_debug("N_VConstrMask");
420 dolfin_not_implemented();
421 return 0;
422 }
423
424 // Pointer to concrete implementation
425 std::shared_ptr<GenericVector> vector;
426
427 // Pointer to SUNDIALS struct
428 std::unique_ptr<_generic_N_Vector> N_V;
429
430 // Structure containing function pointers to vector operations
431 struct _generic_N_Vector_Ops ops = {N_VGetVectorID, // N_Vector_ID (*N_VGetVectorID)(SUNDIALSNVector);
432 N_VClone, // NVector (*N_VClone)(NVector);
433 N_VCloneEmpty, // NVector (*N_VCloneEmpty)(NVector);
434 N_VDestroy, // void (*N_VDestroy)(NVector);
435 NULL, //N_VSpace, // void (*N_VSpace)(NVector, long int *, long int *);
436 N_VGetArrayPointer, // realtype* (*N_VGetArrayPointer)(NVector);
437 N_VSetArrayPointer, // void (*N_VSetArrayPointer)(realtype *, NVector);
438 N_VLinearSum, // void (*N_VLinearSum)(realtype, NVector, realtype, NVector, NVector);
439 N_VConst, // void (*N_VConst)(realtype, NVector);
440 N_VProd, // void (*N_VProd)(NVector, NVector, NVector);
441 N_VDiv, // void (*N_VDiv)(NVector, NVector, NVector);
442 N_VScale, // void (*N_VScale)(realtype, NVector, NVector);
443 N_VAbs, // void (*N_VAbs)(NVector, NVector);
444 N_VInv, // void (*N_VInv)(NVector, NVector);
445 N_VAddConst, // void (*N_VAddConst)(NVector, realtype, NVector);
446 N_VDotProd, // realtype (*N_VDotProd)(NVector, NVector);
447 N_VMaxNorm, // realtype (*N_VMaxNorm)(NVector);
448 N_VWrmsNorm, // realtype (*N_VWrmsNorm)(NVector, NVector);
449 N_VWrmsNormMask, // realtype (*N_VWrmsNormMask)(NVector, NVector, NVector);
450 N_VMin, // realtype (*N_VMin)(NVector);
451 N_VWl2Norm, // realtype (*N_VWl2Norm)(NVector, NVector);
452 N_VL1Norm, // realtype (*N_VL1Norm)(NVector);
453 N_VCompare, // void (*N_VCompare)(realtype, NVector, NVector);
454 N_VInvTest, // booleantype (*N_VInvtest)(NVector, NVector);
455 N_VConstrMask, // booleantype (*N_VConstrMask)(NVector, NVector, NVector);
456 N_VMinQuotient}; // realtype (*N_VMinQuotient)(NVector, NVector);
457 };
458
459
460}
461
462#endif
463
464#endif
Default linear algebra factory based on global parameter "linear_algebra_backend".
Definition DefaultFactory.h:36
virtual std::shared_ptr< GenericVector > create_vector(MPI_Comm comm) const
Create empty vector.
Definition DefaultFactory.cpp:37
This class defines a common interface for vectors.
Definition GenericVector.h:48
Definition SUNDIALSNVector.h:39
N_Vector nvector() const
Definition SUNDIALSNVector.h:95
SUNDIALSNVector(const SUNDIALSNVector &x)
Definition SUNDIALSNVector.h:69
SUNDIALSNVector(const GenericVector &x)
Definition SUNDIALSNVector.h:74
const SUNDIALSNVector & operator=(const SUNDIALSNVector &x)
Assignment operator.
Definition SUNDIALSNVector.h:110
SUNDIALSNVector(MPI_Comm comm, std::size_t N)
Definition SUNDIALSNVector.h:56
SUNDIALSNVector(MPI_Comm comm=MPI_COMM_WORLD)
Definition SUNDIALSNVector.h:45
SUNDIALSNVector(std::shared_ptr< GenericVector > x)
Definition SUNDIALSNVector.h:84
std::shared_ptr< GenericVector > vec() const
Definition SUNDIALSNVector.h:104
Definition adapt.h:30