OASIS
Open Algebra Software
Loading...
Searching...
No Matches
BinaryExpression.hpp
Go to the documentation of this file.
1//
2// Created by Matthew McCall on 7/2/23.
3//
4
5#ifndef OASIS_BINARYEXPRESSION_HPP
6#define OASIS_BINARYEXPRESSION_HPP
7
8#include <algorithm>
9#include <cassert>
10#include <functional>
11#include <list>
12
13#include "Expression.hpp"
14#include "RecursiveCast.hpp"
15#include "Serialization.hpp"
16
17namespace Oasis {
24template <typename MostSigOpT, typename LeastSigOpT, typename T>
26
27template <template <typename, typename> typename T>
28concept IAssociativeAndCommutative = IExpression<T<Expression, Expression>> && ((T<Expression, Expression>::GetStaticCategory() & (Associative | Commutative)) == (Associative | Commutative));
29
36template <template <typename, typename> typename T>
39{
40 if (ops.size() <= 1) {
41 return nullptr;
42 }
43
44 using GeneralizedT = T<Expression, Expression>;
45
47 opsList.resize(ops.size());
48
49 std::transform(ops.begin(), ops.end(), opsList.begin(), [](const auto& op) { return op->Copy(); });
50
51 while (std::next(opsList.begin()) != opsList.end()) {
52 for (auto i = opsList.begin(); i != opsList.end() && std::next(i) != opsList.end();) {
53 auto node = std::make_unique<GeneralizedT>(**i, **std::next(i));
54 opsList.insert(i, std::move(node));
55 i = opsList.erase(i, std::next(i, 2));
56 }
57 }
58
59 auto* result = dynamic_cast<GeneralizedT*>(opsList.front().release());
60 return std::unique_ptr<GeneralizedT>(result);
61}
62
78template <template <IExpression, IExpression> class DerivedT, IExpression MostSigOpT = Expression, IExpression LeastSigOpT = MostSigOpT>
80
81 using DerivedSpecialized = DerivedT<MostSigOpT, LeastSigOpT>;
82 using DerivedGeneralized = DerivedT<Expression, Expression>;
83
84public:
85 BinaryExpression() = default;
87 {
88 if (other.HasMostSigOp()) {
90 }
91
92 if (other.HasLeastSigOp()) {
94 }
95 }
96
97 BinaryExpression(const MostSigOpT& mostSigOp, const LeastSigOpT& leastSigOp)
98 {
101 }
102
103 template <IExpression Op1T, IExpression Op2T, IExpression... OpsT>
104 BinaryExpression(const Op1T& op1, const Op2T& op2, const OpsT&... ops)
105 {
106 static_assert(IAssociativeAndCommutative<DerivedT>, "List initializer only supported for associative and commutative expressions");
107 static_assert(std::is_same_v<DerivedGeneralized, DerivedSpecialized>, "List initializer only supported for generalized expressions");
108
110
111 for (auto opWrapper : std::vector<std::reference_wrapper<const Expression>> { static_cast<const Expression&>(op1), static_cast<const Expression&>(op2), (static_cast<const Expression&>(ops))... }) {
112 const Expression& operand = opWrapper.get();
113 opsVec.emplace_back(operand.Copy());
114 }
115
116 // build expression from vector
117 auto generalized = BuildFromVector<DerivedT>(opsVec);
118
119 SetLeastSigOp(generalized->GetLeastSigOp());
120 SetMostSigOp(generalized->GetMostSigOp());
121 }
122
123 [[nodiscard]] auto Copy() const -> std::unique_ptr<Expression> final
124 {
125 return std::make_unique<DerivedSpecialized>(*static_cast<const DerivedSpecialized*>(this));
126 }
127
128 [[nodiscard]] auto Differentiate(const Expression& differentiationVariable) const -> std::unique_ptr<Expression> override
129 {
130 return Generalize()->Differentiate(differentiationVariable);
131 }
132 [[nodiscard]] auto Equals(const Expression& other) const -> bool final
133 {
134 if (this->GetType() != other.GetType()) {
135 return false;
136 }
137
138 const auto otherGeneralized = other.Generalize();
139 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
140
141 bool mostSigOpMismatch = false, leastSigOpMismatch = false;
142
143 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
144 if (mostSigOp && otherBinaryGeneralized.HasMostSigOp()) {
145 mostSigOpMismatch = !mostSigOp->Equals(otherBinaryGeneralized.GetMostSigOp());
146 }
147 } else {
148 mostSigOpMismatch = true;
149 }
150
151 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
152 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
153 leastSigOpMismatch = !leastSigOp->Equals(otherBinaryGeneralized.GetLeastSigOp());
154 }
155 } else {
156 mostSigOpMismatch = true;
157 }
158
159 if (!mostSigOpMismatch && !leastSigOpMismatch) {
160 return true;
161 }
162
163 if (!(this->GetCategory() & Associative)) {
164 return false;
165 }
166
167 auto thisFlattened = std::vector<std::unique_ptr<Expression>> {};
168 auto otherFlattened = std::vector<std::unique_ptr<Expression>> {};
169
170 this->Flatten(thisFlattened);
171 otherBinaryGeneralized.Flatten(otherFlattened);
172
173 for (const auto& thisOperand : thisFlattened) {
174 if (std::find_if(otherFlattened.begin(), otherFlattened.end(), [&thisOperand](const auto& otherOperand) {
175 return thisOperand->Equals(*otherOperand);
176 })
177 == otherFlattened.end()) {
178 return false;
179 }
180 }
181
182 return true;
183 }
184
185 [[nodiscard]] auto Generalize() const -> std::unique_ptr<Expression> final
186 {
187 DerivedGeneralized generalized;
188
189 if (this->mostSigOp) {
190 generalized.SetMostSigOp(*this->mostSigOp->Copy());
191 }
192
193 if (this->leastSigOp) {
194 generalized.SetLeastSigOp(*this->leastSigOp->Copy());
195 }
196
197 return std::make_unique<DerivedGeneralized>(generalized);
198 }
199
200 [[nodiscard]] auto Simplify() const -> std::unique_ptr<Expression> override
201 {
202 return Generalize()->Simplify();
203 }
204
205 [[nodiscard]] auto Integrate(const Expression& integrationVariable) const -> std::unique_ptr<Expression> override
206 {
207 return Generalize()->Integrate(integrationVariable);
208 }
209
210 [[nodiscard]] auto StructurallyEquivalent(const Expression& other) const -> bool final
211 {
212 if (this->GetType() != other.GetType()) {
213 return false;
214 }
215
216 const std::unique_ptr<Expression> otherGeneralized = other.Generalize();
217 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
218
219 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
220 if (this->HasMostSigOp() && otherBinaryGeneralized.HasMostSigOp()) {
221 if (!mostSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetMostSigOp())) {
222 return false;
223 }
224 }
225 }
226
227 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
228 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
229 if (!leastSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetLeastSigOp())) {
230 return false;
231 }
232 }
233 }
234
235 return true;
236 }
237
250 {
251 if (mostSigOp) {
252 if (this->mostSigOp->template Is<DerivedGeneralized>()) {
253 auto generalizedMostSigOp = this->mostSigOp->Generalize();
254 const auto& mostSigOp = static_cast<const DerivedGeneralized&>(*generalizedMostSigOp);
255 mostSigOp.Flatten(out);
256 } else {
257 out.push_back(this->mostSigOp->Copy());
258 }
259 }
260
261 if (leastSigOp) {
262 if (this->leastSigOp->template Is<DerivedGeneralized>()) {
263 auto generalizedLeastSigOp = this->leastSigOp->Generalize();
264 const auto& leastSigOp = static_cast<const DerivedGeneralized&>(*generalizedLeastSigOp);
265 leastSigOp.Flatten(out);
266 } else {
267 out.push_back(this->leastSigOp->Copy());
268 }
269 }
270 }
271
276 auto GetMostSigOp() const -> const MostSigOpT&
277 {
278 assert(mostSigOp != nullptr);
279 return *mostSigOp;
280 }
281
286 auto GetLeastSigOp() const -> const LeastSigOpT&
287 {
288 assert(leastSigOp != nullptr);
289 return *leastSigOp;
290 }
291
296 [[nodiscard]] auto HasMostSigOp() const -> bool
297 {
298 return mostSigOp != nullptr;
299 }
300
305 [[nodiscard]] auto HasLeastSigOp() const -> bool
306 {
307 return leastSigOp != nullptr;
308 }
309
314 template <typename T>
316 auto SetMostSigOp(const T& op) -> bool
317 {
319 this->mostSigOp = op.Copy();
320 return true;
321 }
322
325 return true;
326 }
327
328 if (auto castedOp = Oasis::RecursiveCast<MostSigOpT>(op); castedOp) {
329 mostSigOp = std::move(castedOp);
330 return true;
331 }
332
333 return false;
334 }
335
340 template <typename T>
342 auto SetLeastSigOp(const T& op) -> bool
343 {
345 this->leastSigOp = op.Copy();
346 return true;
347 }
348
351 return true;
352 }
353
354 if (auto castedOp = Oasis::RecursiveCast<LeastSigOpT>(op); castedOp) {
355 leastSigOp = std::move(castedOp);
356 return true;
357 }
358
359 return false;
360 }
361
362 auto Substitute(const Expression& var, const Expression& val) -> std::unique_ptr<Expression> override
363 {
365 std::unique_ptr<Expression> right = ((GetLeastSigOp().Copy())->Substitute(var, val));
366 DerivedT<Expression, Expression> comb = DerivedT<Expression, Expression> { *left, *right };
367 auto ret = comb.Simplify();
368 return ret;
369 }
374 auto SwapOperands() const -> DerivedT<LeastSigOpT, MostSigOpT>
375 {
376 return DerivedT { *this->leastSigOp, *this->mostSigOp };
377 }
378
379 auto operator=(const BinaryExpression& other) -> BinaryExpression& = default;
380
381 void Serialize(SerializationVisitor& visitor) const override
382 {
383 const auto generalized = Generalize();
384 const auto& derivedGeneralized = dynamic_cast<const DerivedGeneralized&>(*generalized);
385 visitor.Serialize(derivedGeneralized);
386 }
387
390};
391
392} // Oasis
393
394#endif // OASIS_BINARYEXPRESSION_HPP
T begin(T... args)
A binary expression.
Definition BinaryExpression.hpp:79
auto Simplify() const -> std::unique_ptr< Expression > override
Simplifies this expression.
Definition BinaryExpression.hpp:200
auto Copy() const -> std::unique_ptr< Expression > final
Copies this expression.
Definition BinaryExpression.hpp:123
BinaryExpression(const BinaryExpression &other)
Definition BinaryExpression.hpp:86
std::unique_ptr< LeastSigOpT > leastSigOp
Definition BinaryExpression.hpp:389
auto Differentiate(const Expression &differentiationVariable) const -> std::unique_ptr< Expression > override
Tries to differentiate this function.
Definition BinaryExpression.hpp:128
void Serialize(SerializationVisitor &visitor) const override
This function serializes the expression object.
Definition BinaryExpression.hpp:381
auto Substitute(const Expression &var, const Expression &val) -> std::unique_ptr< Expression > override
Definition BinaryExpression.hpp:362
auto GetLeastSigOp() const -> const LeastSigOpT &
Gets the least significant operand of this expression.
Definition BinaryExpression.hpp:286
auto operator=(const BinaryExpression &other) -> BinaryExpression &=default
auto SetLeastSigOp(const T &op) -> bool
Sets the least significant operand of this expression.
Definition BinaryExpression.hpp:342
auto Integrate(const Expression &integrationVariable) const -> std::unique_ptr< Expression > override
Attempts to integrate this expression using integration rules.
Definition BinaryExpression.hpp:205
auto Generalize() const -> std::unique_ptr< Expression > final
Converts this expression to a more general expression.
Definition BinaryExpression.hpp:185
auto GetMostSigOp() const -> const MostSigOpT &
Gets the most significant operand of this expression.
Definition BinaryExpression.hpp:276
std::unique_ptr< MostSigOpT > mostSigOp
Definition BinaryExpression.hpp:388
auto SwapOperands() const -> DerivedT< LeastSigOpT, MostSigOpT >
Swaps the operands of this expression.
Definition BinaryExpression.hpp:374
auto SetMostSigOp(const T &op) -> bool
Sets the most significant operand of this expression.
Definition BinaryExpression.hpp:316
auto StructurallyEquivalent(const Expression &other) const -> bool final
Checks whether this expression is structurally equivalent to another expression.
Definition BinaryExpression.hpp:210
auto Flatten(std::vector< std::unique_ptr< Expression > > &out) const -> void
Flattens this expression.
Definition BinaryExpression.hpp:249
auto HasLeastSigOp() const -> bool
Gets whether this expression has a least significant operand.
Definition BinaryExpression.hpp:305
BinaryExpression(const Op1T &op1, const Op2T &op2, const OpsT &... ops)
Definition BinaryExpression.hpp:104
auto Equals(const Expression &other) const -> bool final
Compares this expression to another expression for equality.
Definition BinaryExpression.hpp:132
BinaryExpression(const MostSigOpT &mostSigOp, const LeastSigOpT &leastSigOp)
Definition BinaryExpression.hpp:97
auto HasMostSigOp() const -> bool
Gets whether this expression has a most significant operand.
Definition BinaryExpression.hpp:296
An expression.
Definition Expression.hpp:57
virtual auto Copy() const -> std::unique_ptr< Expression >=0
Copies this expression.
virtual auto GetCategory() const -> uint32_t
Gets the category of this expression.
Definition Expression.cpp:212
virtual auto GetType() const -> ExpressionType
Gets the type of this expression.
Definition Expression.cpp:220
virtual auto Simplify() const -> std::unique_ptr< Expression >
Simplifies this expression.
Definition Expression.cpp:244
Definition Serialization.hpp:50
virtual void Serialize(const Real &real)=0
Definition BinaryExpression.hpp:28
An expression concept.
Definition Concepts.hpp:25
A concept for an operand of a binary expression.
Definition BinaryExpression.hpp:25
Checks if type T is same as any of the provided types in U.
Definition Concepts.hpp:48
T emplace_back(T... args)
T end(T... args)
T erase(T... args)
T find_if(T... args)
T front(T... args)
T insert(T... args)
T is_same_v
Definition Add.hpp:11
auto BuildFromVector(const std::vector< std::unique_ptr< Expression > > &ops) -> std::unique_ptr< T< Expression, Expression > >
Builds a reasonably balanced binary expression from a vector of operands.
Definition BinaryExpression.hpp:38
@ Commutative
Definition Expression.hpp:45
@ Associative
Definition Expression.hpp:44
T next(T... args)
T resize(T... args)
T transform(T... args)