OASIS
Open Algebra Software
Loading...
Searching...
No Matches
BinaryExpression.hpp
Go to the documentation of this file.
1//
2// Created by Matthew McCall on 7/2/23.
3//
4
5#ifndef OASIS_BINARYEXPRESSION_HPP
6#define OASIS_BINARYEXPRESSION_HPP
7
8#include <algorithm>
9#include <cassert>
10#include <functional>
11#include <list>
12
13#include "Expression.hpp"
15#include "RecursiveCast.hpp"
16#include "Visit.hpp"
17
18namespace Oasis {
28template <typename MostSigOpT, typename LeastSigOpT, typename T>
30
31template <template <typename, typename> typename T>
32concept IAssociativeAndCommutative = IExpression<T<Expression, Expression>> && ((T<Expression, Expression>::GetStaticCategory() & (Associative | Commutative)) == (Associative | Commutative));
33
40template <template <typename, typename> typename T>
43{
44 if (ops.size() <= 1) {
45 return nullptr;
46 }
47
48 using GeneralizedT = T<Expression, Expression>;
49
51 opsList.resize(ops.size());
52
53 std::transform(ops.begin(), ops.end(), opsList.begin(), [](const auto& op) { return op->Copy(); });
54
55 while (std::next(opsList.begin()) != opsList.end()) {
56 for (auto i = opsList.begin(); i != opsList.end() && std::next(i) != opsList.end();) {
57 auto node = std::make_unique<GeneralizedT>(**i, **std::next(i));
58 opsList.insert(i, std::move(node));
59 i = opsList.erase(i, std::next(i, 2));
60 }
61 }
62
63 auto* result = dynamic_cast<GeneralizedT*>(opsList.front().release());
64 return std::unique_ptr<GeneralizedT>(result);
65}
66
82template <template <IExpression, IExpression> class DerivedT, IExpression MostSigOpT = Expression, IExpression LeastSigOpT = MostSigOpT>
84
85 using DerivedSpecialized = DerivedT<MostSigOpT, LeastSigOpT>;
86 using DerivedGeneralized = DerivedT<Expression, Expression>;
87
88public:
89 BinaryExpression() = default;
91 {
92 if (other.HasMostSigOp()) {
94 }
95
96 if (other.HasLeastSigOp()) {
98 }
99 }
100
101 BinaryExpression(const MostSigOpT& mostSigOp, const LeastSigOpT& leastSigOp)
102 {
105 }
106
107 template <IExpression Op1T, IExpression Op2T, IExpression... OpsT>
108 BinaryExpression(const Op1T& op1, const Op2T& op2, const OpsT&... ops)
109 {
110 static_assert(IAssociativeAndCommutative<DerivedT>, "List initializer only supported for associative and commutative expressions");
111 static_assert(std::is_same_v<DerivedGeneralized, DerivedSpecialized>, "List initializer only supported for generalized expressions");
112
114
115 for (auto opWrapper : std::vector<std::reference_wrapper<const Expression>> { static_cast<const Expression&>(op1), static_cast<const Expression&>(op2), (static_cast<const Expression&>(ops))... }) {
116 const Expression& operand = opWrapper.get();
117 opsVec.emplace_back(operand.Copy());
118 }
119
120 // build expression from vector
121 auto generalized = BuildFromVector<DerivedT>(opsVec);
122
123 SetLeastSigOp(generalized->GetLeastSigOp());
124 SetMostSigOp(generalized->GetMostSigOp());
125 }
126
127 [[nodiscard]] auto Copy() const -> std::unique_ptr<Expression> final
128 {
129 return std::make_unique<DerivedSpecialized>(*static_cast<const DerivedSpecialized*>(this));
130 }
131
132 [[nodiscard]] auto Differentiate(const Expression& differentiationVariable) const -> std::unique_ptr<Expression> override
133 {
134 return Generalize()->Differentiate(differentiationVariable);
135 }
136 [[nodiscard]] auto Equals(const Expression& other) const -> bool final
137 {
138 if (this->GetType() != other.GetType()) {
139 return false;
140 }
141
142 const auto otherGeneralized = other.Generalize();
143 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
144
145 bool mostSigOpMismatch = false, leastSigOpMismatch = false;
146
147 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
148 if (mostSigOp && otherBinaryGeneralized.HasMostSigOp()) {
149 mostSigOpMismatch = !mostSigOp->Equals(otherBinaryGeneralized.GetMostSigOp());
150 }
151 } else {
152 mostSigOpMismatch = true;
153 }
154
155 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
156 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
157 leastSigOpMismatch = !leastSigOp->Equals(otherBinaryGeneralized.GetLeastSigOp());
158 }
159 } else {
160 mostSigOpMismatch = true;
161 }
162
163 if (!mostSigOpMismatch && !leastSigOpMismatch) {
164 return true;
165 }
166
167 if (!(this->GetCategory() & Associative)) {
168 return false;
169 }
170
171 auto thisFlattened = std::vector<std::unique_ptr<Expression>> {};
172 auto otherFlattened = std::vector<std::unique_ptr<Expression>> {};
173
174 this->Flatten(thisFlattened);
175 otherBinaryGeneralized.Flatten(otherFlattened);
176
177 for (const auto& thisOperand : thisFlattened) {
178 if (std::find_if(otherFlattened.begin(), otherFlattened.end(), [&thisOperand](const auto& otherOperand) {
179 return thisOperand->Equals(*otherOperand);
180 })
181 == otherFlattened.end()) {
182 return false;
183 }
184 }
185
186 return true;
187 }
188
189 [[nodiscard]] auto Generalize() const -> std::unique_ptr<Expression> final
190 {
191 DerivedGeneralized generalized;
192
193 if (this->mostSigOp) {
194 generalized.SetMostSigOp(*this->mostSigOp->Copy());
195 }
196
197 if (this->leastSigOp) {
198 generalized.SetLeastSigOp(*this->leastSigOp->Copy());
199 }
200
201 return std::make_unique<DerivedGeneralized>(generalized);
202 }
203
204 [[nodiscard]] auto Simplify() const -> std::unique_ptr<Expression> override
205 {
206 SimplifyVisitor simplifyVisitor {};
207 auto e = Generalize();
208 auto s = e->Accept(simplifyVisitor);
209 if (!s) {
210 return e;
211 }
212 return std::move(s).value();
213 }
214
215 [[nodiscard]] auto Integrate(const Expression& integrationVariable) const -> std::unique_ptr<Expression> override
216 {
217 return Generalize()->Integrate(integrationVariable);
218 }
219
220 [[nodiscard]] auto StructurallyEquivalent(const Expression& other) const -> bool final
221 {
222 if (this->GetType() != other.GetType()) {
223 return false;
224 }
225
226 const std::unique_ptr<Expression> otherGeneralized = other.Generalize();
227 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
228
229 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
230 if (this->HasMostSigOp() && otherBinaryGeneralized.HasMostSigOp()) {
231 if (!mostSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetMostSigOp())) {
232 return false;
233 }
234 }
235 }
236
237 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
238 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
239 if (!leastSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetLeastSigOp())) {
240 return false;
241 }
242 }
243 }
244
245 return true;
246 }
247
260 {
261 if (mostSigOp) {
262 if (this->mostSigOp->template Is<DerivedGeneralized>()) {
263 auto generalizedMostSigOp = this->mostSigOp->Generalize();
264 const auto& mostSigOp = static_cast<const DerivedGeneralized&>(*generalizedMostSigOp);
265 mostSigOp.Flatten(out);
266 } else {
267 out.push_back(this->mostSigOp->Copy());
268 }
269 }
270
271 if (leastSigOp) {
272 if (this->leastSigOp->template Is<DerivedGeneralized>()) {
273 auto generalizedLeastSigOp = this->leastSigOp->Generalize();
274 const auto& leastSigOp = static_cast<const DerivedGeneralized&>(*generalizedLeastSigOp);
275 leastSigOp.Flatten(out);
276 } else {
277 out.push_back(this->leastSigOp->Copy());
278 }
279 }
280 }
281
286 auto GetMostSigOp() const -> const MostSigOpT&
287 {
288 assert(mostSigOp != nullptr);
289 return *mostSigOp;
290 }
291
296 auto GetLeastSigOp() const -> const LeastSigOpT&
297 {
298 assert(leastSigOp != nullptr);
299 return *leastSigOp;
300 }
301
306 [[nodiscard]] auto HasMostSigOp() const -> bool
307 {
308 return mostSigOp != nullptr;
309 }
310
315 [[nodiscard]] auto HasLeastSigOp() const -> bool
316 {
317 return leastSigOp != nullptr;
318 }
319
324 template <typename T>
326 auto SetMostSigOp(const T& op) -> bool
327 {
329 this->mostSigOp = op.Copy();
330 return true;
331 }
332
335 return true;
336 }
337
338 if (auto castedOp = Oasis::RecursiveCast<MostSigOpT>(op); castedOp) {
339 mostSigOp = std::move(castedOp);
340 return true;
341 }
342
343 return false;
344 }
345
350 template <typename T>
352 auto SetLeastSigOp(const T& op) -> bool
353 {
355 this->leastSigOp = op.Copy();
356 return true;
357 }
358
361 return true;
362 }
363
364 if (auto castedOp = Oasis::RecursiveCast<LeastSigOpT>(op); castedOp) {
365 leastSigOp = std::move(castedOp);
366 return true;
367 }
368
369 return false;
370 }
371
372 auto Substitute(const Expression& var, const Expression& val) -> std::unique_ptr<Expression> override
373 {
374 // TODO: FIX WITH VISITOR?
376 std::unique_ptr<Expression> right = ((GetLeastSigOp().Copy())->Substitute(var, val));
377 DerivedT<Expression, Expression> comb = DerivedT<Expression, Expression> { *left, *right };
378
379 Oasis::SimplifyVisitor simplifyVisitor {};
380 auto simplified = comb.Accept(simplifyVisitor);
381 if (!simplified) {
382 return comb.Generalize();
383 }
384 return std::move(simplified.value());
385 }
390 auto SwapOperands() const -> DerivedT<LeastSigOpT, MostSigOpT>
391 {
392 return DerivedT { *this->leastSigOp, *this->mostSigOp };
393 }
394
395 auto operator=(const BinaryExpression& other) -> BinaryExpression& = default;
396
397 auto AcceptInternal(Visitor& visitor) const -> any override
398 {
399 const auto generalized = Generalize();
400 const auto& derivedGeneralized = dynamic_cast<const DerivedGeneralized&>(*generalized);
401 return visitor.Visit(derivedGeneralized);
402 }
403
406};
407
408} // Oasis
409
410#endif // OASIS_BINARYEXPRESSION_HPP
T begin(T... args)
A binary expression.
Definition BinaryExpression.hpp:83
auto Simplify() const -> std::unique_ptr< Expression > override
Simplifies this expression.
Definition BinaryExpression.hpp:204
auto Copy() const -> std::unique_ptr< Expression > final
Copies this expression.
Definition BinaryExpression.hpp:127
BinaryExpression(const BinaryExpression &other)
Definition BinaryExpression.hpp:90
std::unique_ptr< LeastSigOpT > leastSigOp
Definition BinaryExpression.hpp:405
auto Differentiate(const Expression &differentiationVariable) const -> std::unique_ptr< Expression > override
Tries to differentiate this function.
Definition BinaryExpression.hpp:132
auto AcceptInternal(Visitor &visitor) const -> any override
Dispatches the given visitor on the generalized form of this expression.
Definition BinaryExpression.hpp:397
auto Substitute(const Expression &var, const Expression &val) -> std::unique_ptr< Expression > override
Definition BinaryExpression.hpp:372
auto GetLeastSigOp() const -> const LeastSigOpT &
Gets the least significant operand of this expression.
Definition BinaryExpression.hpp:296
auto operator=(const BinaryExpression &other) -> BinaryExpression &=default
auto SetLeastSigOp(const T &op) -> bool
Sets the least significant operand of this expression.
Definition BinaryExpression.hpp:352
auto Integrate(const Expression &integrationVariable) const -> std::unique_ptr< Expression > override
Attempts to integrate this expression using integration rules.
Definition BinaryExpression.hpp:215
auto Generalize() const -> std::unique_ptr< Expression > final
Converts this expression to a more general expression.
Definition BinaryExpression.hpp:189
auto GetMostSigOp() const -> const MostSigOpT &
Gets the most significant operand of this expression.
Definition BinaryExpression.hpp:286
std::unique_ptr< MostSigOpT > mostSigOp
Definition BinaryExpression.hpp:404
auto SwapOperands() const -> DerivedT< LeastSigOpT, MostSigOpT >
Swaps the operands of this expression.
Definition BinaryExpression.hpp:390
auto SetMostSigOp(const T &op) -> bool
Sets the most significant operand of this expression.
Definition BinaryExpression.hpp:326
auto StructurallyEquivalent(const Expression &other) const -> bool final
Checks whether this expression is structurally equivalent to another expression.
Definition BinaryExpression.hpp:220
auto Flatten(std::vector< std::unique_ptr< Expression > > &out) const -> void
Flattens this expression.
Definition BinaryExpression.hpp:259
auto HasLeastSigOp() const -> bool
Gets whether this expression has a least significant operand.
Definition BinaryExpression.hpp:315
BinaryExpression(const Op1T &op1, const Op2T &op2, const OpsT &... ops)
Definition BinaryExpression.hpp:108
auto Equals(const Expression &other) const -> bool final
Compares this expression to another expression for equality.
Definition BinaryExpression.hpp:136
BinaryExpression(const MostSigOpT &mostSigOp, const LeastSigOpT &leastSigOp)
Definition BinaryExpression.hpp:101
auto HasMostSigOp() const -> bool
Gets whether this expression has a most significant operand.
Definition BinaryExpression.hpp:306
An expression.
Definition Expression.hpp:63
virtual auto Copy() const -> std::unique_ptr< Expression >=0
Copies this expression.
virtual auto GetCategory() const -> uint32_t
Gets the category of this expression.
Definition Expression.cpp:212
virtual auto GetType() const -> ExpressionType
Gets the type of this expression.
Definition Expression.cpp:220
Definition SimplifyVisitor.hpp:25
Definition Visit.hpp:13
Definition BinaryExpression.hpp:32
An expression concept.
Definition Concepts.hpp:30
A concept for an operand of a binary expression.
Definition BinaryExpression.hpp:29
Checks whether type T is the same as any of the types in U.
Definition Concepts.hpp:53
T emplace_back(T... args)
T end(T... args)
T erase(T... args)
T find_if(T... args)
T front(T... args)
T insert(T... args)
T is_same_v
Definition Add.hpp:11
auto BuildFromVector(const std::vector< std::unique_ptr< Expression > > &ops) -> std::unique_ptr< T< Expression, Expression > >
Builds a reasonably balanced binary expression from a vector of operands.
Definition BinaryExpression.hpp:42
boost::anys::unique_any any
Definition Expression.hpp:15
@ Commutative
Definition Expression.hpp:51
@ Associative
Definition Expression.hpp:50
T next(T... args)
T resize(T... args)
T transform(T... args)