Intrepid2
Intrepid2_DirectSumBasis.hpp
Go to the documentation of this file.
1// @HEADER
2// ************************************************************************
3//
4// Intrepid2 Package
5// Copyright (2007) Sandia Corporation
6//
7// Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
8// license for use of this work by or on behalf of the U.S. Government.
9//
10// Redistribution and use in source and binary forms, with or without
11// modification, are permitted provided that the following conditions are
12// met:
13//
14// 1. Redistributions of source code must retain the above copyright
15// notice, this list of conditions and the following disclaimer.
16//
17// 2. Redistributions in binary form must reproduce the above copyright
18// notice, this list of conditions and the following disclaimer in the
19// documentation and/or other materials provided with the distribution.
20//
21// 3. Neither the name of the Corporation nor the names of the
22// contributors may be used to endorse or promote products derived from
23// this software without specific prior written permission.
24//
25// THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
26// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
28// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
29// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36//
37// Questions? Contact Kyungjoo Kim (kyukim@sandia.gov),
38// Mauro Perego (mperego@sandia.gov), or
39// Nate Roberts (nvrober@sandia.gov)
40//
41// ************************************************************************
42// @HEADER
43
49#ifndef Intrepid2_DirectSumBasis_h
50#define Intrepid2_DirectSumBasis_h
51
#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include <Kokkos_View.hpp>
#include <Kokkos_DynRankView.hpp>
54
55namespace Intrepid2
56{
67 template<typename BasisBaseClass>
68 class Basis_DirectSumBasis : public BasisBaseClass
69 {
70 public:
71 using BasisBase = BasisBaseClass;
72 using BasisPtr = Teuchos::RCP<BasisBase>;
73
74 using DeviceType = typename BasisBase::DeviceType;
75 using ExecutionSpace = typename BasisBase::ExecutionSpace;
76 using OutputValueType = typename BasisBase::OutputValueType;
77 using PointValueType = typename BasisBase::PointValueType;
78
79 using OrdinalTypeArray1DHost = typename BasisBase::OrdinalTypeArray1DHost;
80 using OrdinalTypeArray2DHost = typename BasisBase::OrdinalTypeArray2DHost;
81 using OutputViewType = typename BasisBase::OutputViewType;
82 using PointViewType = typename BasisBase::PointViewType;
83 using ScalarViewType = typename BasisBase::ScalarViewType;
84 protected:
85 BasisPtr basis1_;
86 BasisPtr basis2_;
87
88 std::string name_;
89 public:
94 Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
95 :
96 basis1_(basis1),basis2_(basis2)
97 {
98 INTREPID2_TEST_FOR_EXCEPTION(basis1->getBasisType() != basis2->getBasisType(), std::invalid_argument, "basis1 and basis2 must agree in basis type");
99 INTREPID2_TEST_FOR_EXCEPTION(basis1->getBaseCellTopology().getKey() != basis2->getBaseCellTopology().getKey(),
100 std::invalid_argument, "basis1 and basis2 must agree in cell topology");
101 INTREPID2_TEST_FOR_EXCEPTION(basis1->getCoordinateSystem() != basis2->getCoordinateSystem(),
102 std::invalid_argument, "basis1 and basis2 must agree in coordinate system");
103
104 this->basisCardinality_ = basis1->getCardinality() + basis2->getCardinality();
105 this->basisDegree_ = std::max(basis1->getDegree(), basis2->getDegree());
106
107 {
108 std::ostringstream basisName;
109 basisName << basis1->getName() << " + " << basis2->getName();
110 name_ = basisName.str();
111 }
112
113 this->basisCellTopology_ = basis1->getBaseCellTopology();
114 this->basisType_ = basis1->getBasisType();
115 this->basisCoordinates_ = basis1->getCoordinateSystem();
116
117 if (this->basisType_ == BASIS_FEM_HIERARCHICAL)
118 {
119 int degreeLength = basis1_->getPolynomialDegreeLength();
120 INTREPID2_TEST_FOR_EXCEPTION(degreeLength != basis2_->getPolynomialDegreeLength(), std::invalid_argument, "Basis1 and Basis2 must agree on polynomial degree length");
121
122 this->fieldOrdinalPolynomialDegree_ = OrdinalTypeArray2DHost("DirectSumBasis degree lookup",this->basisCardinality_,degreeLength);
123 // our field ordinals start with basis1_; basis2_ follows
124 for (int fieldOrdinal1=0; fieldOrdinal1<basis1_->getCardinality(); fieldOrdinal1++)
125 {
126 int fieldOrdinal = fieldOrdinal1;
127 auto polynomialDegree = basis1->getPolynomialDegreeOfField(fieldOrdinal1);
128 for (int d=0; d<degreeLength; d++)
129 {
130 this->fieldOrdinalPolynomialDegree_(fieldOrdinal,d) = polynomialDegree(d);
131 }
132 }
133 for (int fieldOrdinal2=0; fieldOrdinal2<basis2_->getCardinality(); fieldOrdinal2++)
134 {
135 int fieldOrdinal = basis1->getCardinality() + fieldOrdinal2;
136
137 auto polynomialDegree = basis2->getPolynomialDegreeOfField(fieldOrdinal2);
138 for (int d=0; d<degreeLength; d++)
139 {
140 this->fieldOrdinalPolynomialDegree_(fieldOrdinal,d) = polynomialDegree(d);
141 }
142 }
143 }
144
145 // initialize tags
146 {
147 const auto & cardinality = this->basisCardinality_;
148
149 // Basis-dependent initializations
150 const ordinal_type tagSize = 4; // size of DoF tag, i.e., number of fields in the tag
151 const ordinal_type posScDim = 0; // position in the tag, counting from 0, of the subcell dim
152 const ordinal_type posScOrd = 1; // position in the tag, counting from 0, of the subcell ordinal
153 const ordinal_type posDfOrd = 2; // position in the tag, counting from 0, of DoF ordinal relative to the subcell
154
155 OrdinalTypeArray1DHost tagView("tag view", cardinality*tagSize);
156
157 shards::CellTopology cellTopo = this->basisCellTopology_;
158
159 unsigned spaceDim = cellTopo.getDimension();
160
161 ordinal_type basis2Offset = basis1_->getCardinality();
162
163 for (unsigned d=0; d<=spaceDim; d++)
164 {
165 unsigned subcellCount = cellTopo.getSubcellCount(d);
166 for (unsigned subcellOrdinal=0; subcellOrdinal<subcellCount; subcellOrdinal++)
167 {
168 ordinal_type subcellDofCount1 = basis1->getDofCount(d, subcellOrdinal);
169 ordinal_type subcellDofCount2 = basis2->getDofCount(d, subcellOrdinal);
170
171 ordinal_type subcellDofCount = subcellDofCount1 + subcellDofCount2;
172 for (ordinal_type localDofID=0; localDofID<subcellDofCount; localDofID++)
173 {
174 ordinal_type fieldOrdinal;
175 if (localDofID < subcellDofCount1)
176 {
177 // first basis: field ordinal matches the basis1 ordinal
178 fieldOrdinal = basis1_->getDofOrdinal(d, subcellOrdinal, localDofID);
179 }
180 else
181 {
182 // second basis: field ordinal is offset by basis1 cardinality
183 fieldOrdinal = basis2Offset + basis2_->getDofOrdinal(d, subcellOrdinal, localDofID - subcellDofCount1);
184 }
185 tagView(fieldOrdinal*tagSize+0) = d; // subcell dimension
186 tagView(fieldOrdinal*tagSize+1) = subcellOrdinal;
187 tagView(fieldOrdinal*tagSize+2) = localDofID;
188 tagView(fieldOrdinal*tagSize+3) = subcellDofCount;
189 }
190 }
191 }
192 // // Basis-independent function sets tag and enum data in tagToOrdinal_ and ordinalToTag_ arrays:
193 // // tags are constructed on host
194 this->setOrdinalTagData(this->tagToOrdinal_,
195 this->ordinalToTag_,
196 tagView,
197 this->basisCardinality_,
198 tagSize,
199 posScDim,
200 posScOrd,
201 posDfOrd);
202 }
203 }
204
210 virtual BasisValues<OutputValueType,DeviceType> allocateBasisValues( TensorPoints<PointValueType,DeviceType> points, const EOperator operatorType = OPERATOR_VALUE) const override
211 {
212 BasisValues<OutputValueType,DeviceType> basisValues1 = basis1_->allocateBasisValues(points, operatorType);
213 BasisValues<OutputValueType,DeviceType> basisValues2 = basis2_->allocateBasisValues(points, operatorType);
214
215 const int numScalarFamilies1 = basisValues1.numTensorDataFamilies();
216 if (numScalarFamilies1 > 0)
217 {
218 // then both basis1 and basis2 should be scalar-valued; check that for basis2:
219 const int numScalarFamilies2 = basisValues2.numTensorDataFamilies();
220 INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() <=0, std::invalid_argument, "When basis1 has scalar value, basis2 must also");
221 std::vector< TensorData<OutputValueType,DeviceType> > scalarFamilies(numScalarFamilies1 + numScalarFamilies2);
222 for (int i=0; i<numScalarFamilies1; i++)
223 {
224 scalarFamilies[i] = basisValues1.tensorData(i);
225 }
226 for (int i=0; i<numScalarFamilies2; i++)
227 {
228 scalarFamilies[i+numScalarFamilies1] = basisValues2.tensorData(i);
229 }
230 return BasisValues<OutputValueType,DeviceType>(scalarFamilies);
231 }
232 else
233 {
234 // then both basis1 and basis2 should be vector-valued; check that:
235 INTREPID2_TEST_FOR_EXCEPTION(!basisValues1.vectorData().isValid(), std::invalid_argument, "When basis1 does not have tensorData() defined, it must have a valid vectorData()");
236 INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() > 0, std::invalid_argument, "When basis1 has vector value, basis2 must also");
237
238 const auto & vectorData1 = basisValues1.vectorData();
239 const auto & vectorData2 = basisValues2.vectorData();
240
241 const int numFamilies1 = vectorData1.numFamilies();
242 const int numComponents = vectorData1.numComponents();
243 INTREPID2_TEST_FOR_EXCEPTION(numComponents != vectorData2.numComponents(), std::invalid_argument, "basis1 and basis2 must agree on the number of components in each vector");
244 const int numFamilies2 = vectorData2.numFamilies();
245
246 const int numFamilies = numFamilies1 + numFamilies2;
247 std::vector< std::vector<TensorData<OutputValueType,DeviceType> > > vectorComponents(numFamilies, std::vector<TensorData<OutputValueType,DeviceType> >(numComponents));
248
249 for (int i=0; i<numFamilies1; i++)
250 {
251 for (int j=0; j<numComponents; j++)
252 {
253 vectorComponents[i][j] = vectorData1.getComponent(i,j);
254 }
255 }
256 for (int i=0; i<numFamilies2; i++)
257 {
258 for (int j=0; j<numComponents; j++)
259 {
260 vectorComponents[i+numFamilies1][j] = vectorData2.getComponent(i,j);
261 }
262 }
263 VectorData<OutputValueType,DeviceType> vectorData(vectorComponents);
265 }
266 }
267
276 virtual void getDofCoords( ScalarViewType dofCoords ) const override {
277 const int basisCardinality1 = basis1_->getCardinality();
278 const int basisCardinality2 = basis2_->getCardinality();
279 const int basisCardinality = basisCardinality1 + basisCardinality2;
280
281 auto dofCoords1 = Kokkos::subview(dofCoords, std::make_pair(0,basisCardinality1), Kokkos::ALL());
282 auto dofCoords2 = Kokkos::subview(dofCoords, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
283
284 basis1_->getDofCoords(dofCoords1);
285 basis2_->getDofCoords(dofCoords2);
286 }
287
299 virtual void getDofCoeffs( ScalarViewType dofCoeffs ) const override {
300 const int basisCardinality1 = basis1_->getCardinality();
301 const int basisCardinality2 = basis2_->getCardinality();
302 const int basisCardinality = basisCardinality1 + basisCardinality2;
303
304 auto dofCoeffs1 = Kokkos::subview(dofCoeffs, std::make_pair(0,basisCardinality1), Kokkos::ALL());
305 auto dofCoeffs2 = Kokkos::subview(dofCoeffs, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
306
307 basis1_->getDofCoeffs(dofCoeffs1);
308 basis2_->getDofCoeffs(dofCoeffs2);
309 }
310
311
316 virtual
317 const char*
318 getName() const override {
319 return name_.c_str();
320 }
321
322 // since the getValues() below only overrides the FEM variants, we specify that
323 // we use the base class's getValues(), which implements the FVD variant by throwing an exception.
324 // (It's an error to use the FVD variant on this basis.)
325 using BasisBase::getValues;
326
338 virtual
339 void
342 const EOperator operatorType = OPERATOR_VALUE ) const override
343 {
344 const int fieldStartOrdinal1 = 0;
345 const int numFields1 = basis1_->getCardinality();
346 const int fieldStartOrdinal2 = numFields1;
347 const int numFields2 = basis2_->getCardinality();
348
349 auto basisValues1 = outputValues.basisValuesForFields(fieldStartOrdinal1, numFields1);
350 auto basisValues2 = outputValues.basisValuesForFields(fieldStartOrdinal2, numFields2);
351
352 basis1_->getValues(basisValues1, inputPoints, operatorType);
353 basis2_->getValues(basisValues2, inputPoints, operatorType);
354 }
355
374 virtual void getValues( OutputViewType outputValues, const PointViewType inputPoints,
375 const EOperator operatorType = OPERATOR_VALUE ) const override
376 {
377 int cardinality1 = basis1_->getCardinality();
378 int cardinality2 = basis2_->getCardinality();
379
380 auto range1 = std::make_pair(0,cardinality1);
381 auto range2 = std::make_pair(cardinality1,cardinality1+cardinality2);
382 if (outputValues.rank() == 2) // F,P
383 {
384 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL());
385 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL());
386
387 basis1_->getValues(outputValues1, inputPoints, operatorType);
388 basis2_->getValues(outputValues2, inputPoints, operatorType);
389 }
390 else if (outputValues.rank() == 3) // F,P,D
391 {
392 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL());
393 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL());
394
395 basis1_->getValues(outputValues1, inputPoints, operatorType);
396 basis2_->getValues(outputValues2, inputPoints, operatorType);
397 }
398 else if (outputValues.rank() == 4) // F,P,D,D
399 {
400 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
401 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
402
403 basis1_->getValues(outputValues1, inputPoints, operatorType);
404 basis2_->getValues(outputValues2, inputPoints, operatorType);
405 }
406 else if (outputValues.rank() == 5) // F,P,D,D,D
407 {
408 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
409 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
410
411 basis1_->getValues(outputValues1, inputPoints, operatorType);
412 basis2_->getValues(outputValues2, inputPoints, operatorType);
413 }
414 else if (outputValues.rank() == 6) // F,P,D,D,D,D
415 {
416 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
417 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
418
419 basis1_->getValues(outputValues1, inputPoints, operatorType);
420 basis2_->getValues(outputValues2, inputPoints, operatorType);
421 }
422 else if (outputValues.rank() == 7) // F,P,D,D,D,D,D
423 {
424 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
425 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
426
427 basis1_->getValues(outputValues1, inputPoints, operatorType);
428 basis2_->getValues(outputValues2, inputPoints, operatorType);
429 }
430 else
431 {
432 INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Unsupported outputValues rank");
433 }
434 }
435 };
436} // end namespace Intrepid2
437
438#endif /* Intrepid2_DirectSumBasis_h */
The data containers in Intrepid2 that support sum factorization and other reduced-data optimizations ...
const VectorDataType & vectorData() const
VectorData accessor.
TensorDataType & tensorData()
TensorData accessor for single-family scalar data.
BasisValues< Scalar, ExecSpaceType > basisValuesForFields(const int &fieldStartOrdinal, const int &numFields)
field start and length must align with families in vectorData_ or tensorDataFamilies_ (whichever is valid).
A basis that is the direct sum of two other bases.
virtual const char * getName() const override
Returns basis name.
virtual BasisValues< OutputValueType, DeviceType > allocateBasisValues(TensorPoints< PointValueType, DeviceType > points, const EOperator operatorType=OPERATOR_VALUE) const override
Allocate BasisValues container suitable for passing to the getValues() variant that takes a TensorPoints container as argument.
Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
Constructor.
virtual void getValues(OutputViewType outputValues, const PointViewType inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell.
virtual void getDofCoords(ScalarViewType dofCoords) const override
Fills in spatial locations (coordinates) of degrees of freedom (nodes) on the reference cell.
virtual void getValues(BasisValues< OutputValueType, DeviceType > outputValues, const TensorPoints< PointValueType, DeviceType > inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell, using point and output value containers that allow pre...
virtual void getDofCoeffs(ScalarViewType dofCoeffs) const override
Fills in coefficients of degrees of freedom for Lagrangian basis on the reference cell.
View-like interface to tensor data; tensor components are stored separately and multiplied together a...
View-like interface to tensor points; point components are stored separately; the appropriate coordin...
Reference-space field values for a basis, designed to support typical vector-valued bases.
KOKKOS_INLINE_FUNCTION constexpr bool isValid() const
returns true for containers that have data; false for those that don't (e.g., those that have been co...
KOKKOS_INLINE_FUNCTION int numFamilies() const
returns the number of families