OASIS
Open Algebra Software
Loading...
Searching...
No Matches
BinaryExpression.hpp
Go to the documentation of this file.
1//
2// Created by Matthew McCall on 7/2/23.
3//
4
5#ifndef OASIS_BINARYEXPRESSION_HPP
6#define OASIS_BINARYEXPRESSION_HPP
7
8#include <algorithm>
9#include <cassert>
10#include <functional>
11#include <list>
12
13#include "Expression.hpp"
15#include "RecursiveCast.hpp"
16#include "Visit.hpp"
17
18namespace Oasis {
28template <typename MostSigOpT, typename LeastSigOpT, typename T>
30
31template <template <typename, typename> typename T>
32concept IAssociativeAndCommutative = IExpression<T<Expression, Expression>> && ((T<Expression, Expression>::GetStaticCategory() & (Associative | Commutative)) == (Associative | Commutative));
33
40template <template <typename, typename> typename T>
43{
44 if (ops.size() <= 1) {
45 return nullptr;
46 }
47
48 using GeneralizedT = T<Expression, Expression>;
49
51 opsList.resize(ops.size());
52
53 std::transform(ops.begin(), ops.end(), opsList.begin(), [](const auto& op) { return op->Copy(); });
54
55 while (std::next(opsList.begin()) != opsList.end()) {
56 for (auto i = opsList.begin(); i != opsList.end() && std::next(i) != opsList.end();) {
57 auto node = std::make_unique<GeneralizedT>(**i, **std::next(i));
58 opsList.insert(i, std::move(node));
59 i = opsList.erase(i, std::next(i, 2));
60 }
61 }
62
63 auto* result = dynamic_cast<GeneralizedT*>(opsList.front().release());
64 return std::unique_ptr<GeneralizedT>(result);
65}
66
82template <template <IExpression, IExpression> class DerivedT, IExpression MostSigOpT = Expression, IExpression LeastSigOpT = MostSigOpT>
84
85 using DerivedSpecialized = DerivedT<MostSigOpT, LeastSigOpT>;
86 using DerivedGeneralized = DerivedT<Expression, Expression>;
87
88public:
89 BinaryExpression() = default;
91 {
92 if (other.HasMostSigOp()) {
94 }
95
96 if (other.HasLeastSigOp()) {
98 }
99 }
100
101 BinaryExpression(const MostSigOpT& mostSigOp, const LeastSigOpT& leastSigOp)
102 {
105 }
106
107 template <IExpression Op1T, IExpression Op2T, IExpression... OpsT>
108 BinaryExpression(const Op1T& op1, const Op2T& op2, const OpsT&... ops)
109 {
110 static_assert(IAssociativeAndCommutative<DerivedT>, "List initializer only supported for associative and commutative expressions");
111 static_assert(std::is_same_v<DerivedGeneralized, DerivedSpecialized>, "List initializer only supported for generalized expressions");
112
114
115 for (auto opWrapper : std::vector<std::reference_wrapper<const Expression>> { static_cast<const Expression&>(op1), static_cast<const Expression&>(op2), (static_cast<const Expression&>(ops))... }) {
116 const Expression& operand = opWrapper.get();
117 opsVec.emplace_back(operand.Copy());
118 }
119
120 // build expression from vector
121 auto generalized = BuildFromVector<DerivedT>(opsVec);
122
123 SetLeastSigOp(generalized->GetLeastSigOp());
124 SetMostSigOp(generalized->GetMostSigOp());
125 }
126
127 [[nodiscard]] auto Copy() const -> std::unique_ptr<Expression> final
128 {
129 return std::make_unique<DerivedSpecialized>(*static_cast<const DerivedSpecialized*>(this));
130 }
131
132 [[nodiscard]] auto Equals(const Expression& other) const -> bool final
133 {
134 if (this->GetType() != other.GetType()) {
135 return false;
136 }
137
138 const auto otherGeneralized = other.Generalize();
139 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
140
141 bool mostSigOpMismatch = false, leastSigOpMismatch = false;
142
143 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
144 if (mostSigOp && otherBinaryGeneralized.HasMostSigOp()) {
145 mostSigOpMismatch = !mostSigOp->Equals(otherBinaryGeneralized.GetMostSigOp());
146 }
147 } else {
148 mostSigOpMismatch = true;
149 }
150
151 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
152 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
153 leastSigOpMismatch = !leastSigOp->Equals(otherBinaryGeneralized.GetLeastSigOp());
154 }
155 } else {
156 mostSigOpMismatch = true;
157 }
158
159 if (!mostSigOpMismatch && !leastSigOpMismatch) {
160 return true;
161 }
162
163 if (!(this->GetCategory() & Associative)) {
164 return false;
165 }
166
167 auto thisFlattened = std::vector<std::unique_ptr<Expression>> {};
168 auto otherFlattened = std::vector<std::unique_ptr<Expression>> {};
169
170 this->Flatten(thisFlattened);
171 otherBinaryGeneralized.Flatten(otherFlattened);
172
173 for (const auto& thisOperand : thisFlattened) {
174 if (std::find_if(otherFlattened.begin(), otherFlattened.end(), [&thisOperand](const auto& otherOperand) {
175 return thisOperand->Equals(*otherOperand);
176 })
177 == otherFlattened.end()) {
178 return false;
179 }
180 }
181
182 return true;
183 }
184
185 [[nodiscard]] auto Generalize() const -> std::unique_ptr<Expression> final
186 {
187 DerivedGeneralized generalized;
188
189 if (this->mostSigOp) {
190 generalized.SetMostSigOp(*this->mostSigOp->Copy());
191 }
192
193 if (this->leastSigOp) {
194 generalized.SetLeastSigOp(*this->leastSigOp->Copy());
195 }
196
197 return std::make_unique<DerivedGeneralized>(generalized);
198 }
199
200 [[nodiscard]] auto Integrate(const Expression& integrationVariable) const -> std::unique_ptr<Expression> override
201 {
202 return Generalize()->Integrate(integrationVariable);
203 }
204
205 [[nodiscard]] auto StructurallyEquivalent(const Expression& other) const -> bool final
206 {
207 if (this->GetType() != other.GetType()) {
208 return false;
209 }
210
211 const std::unique_ptr<Expression> otherGeneralized = other.Generalize();
212 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
213
214 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
215 if (this->HasMostSigOp() && otherBinaryGeneralized.HasMostSigOp()) {
216 if (!mostSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetMostSigOp())) {
217 return false;
218 }
219 }
220 }
221
222 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
223 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
224 if (!leastSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetLeastSigOp())) {
225 return false;
226 }
227 }
228 }
229
230 return true;
231 }
232
245 {
246 if (mostSigOp) {
247 if (this->mostSigOp->template Is<DerivedGeneralized>()) {
248 auto generalizedMostSigOp = this->mostSigOp->Generalize();
249 const auto& mostSigOp = static_cast<const DerivedGeneralized&>(*generalizedMostSigOp);
250 mostSigOp.Flatten(out);
251 } else {
252 out.push_back(this->mostSigOp->Copy());
253 }
254 }
255
256 if (leastSigOp) {
257 if (this->leastSigOp->template Is<DerivedGeneralized>()) {
258 auto generalizedLeastSigOp = this->leastSigOp->Generalize();
259 const auto& leastSigOp = static_cast<const DerivedGeneralized&>(*generalizedLeastSigOp);
260 leastSigOp.Flatten(out);
261 } else {
262 out.push_back(this->leastSigOp->Copy());
263 }
264 }
265 }
266
271 auto GetMostSigOp() const -> const MostSigOpT&
272 {
273 assert(mostSigOp != nullptr);
274 return *mostSigOp;
275 }
276
281 auto GetLeastSigOp() const -> const LeastSigOpT&
282 {
283 assert(leastSigOp != nullptr);
284 return *leastSigOp;
285 }
286
291 [[nodiscard]] auto HasMostSigOp() const -> bool
292 {
293 return mostSigOp != nullptr;
294 }
295
300 [[nodiscard]] auto HasLeastSigOp() const -> bool
301 {
302 return leastSigOp != nullptr;
303 }
304
309 template <typename T>
311 auto SetMostSigOp(const T& op) -> bool
312 {
314 this->mostSigOp = op.Copy();
315 return true;
316 }
317
320 return true;
321 }
322
323 if (auto castedOp = Oasis::RecursiveCast<MostSigOpT>(op); castedOp) {
324 mostSigOp = std::move(castedOp);
325 return true;
326 }
327
328 return false;
329 }
330
335 template <typename T>
337 auto SetLeastSigOp(const T& op) -> bool
338 {
340 this->leastSigOp = op.Copy();
341 return true;
342 }
343
346 return true;
347 }
348
349 if (auto castedOp = Oasis::RecursiveCast<LeastSigOpT>(op); castedOp) {
350 leastSigOp = std::move(castedOp);
351 return true;
352 }
353
354 return false;
355 }
356
357 auto Substitute(const Expression& var, const Expression& val) -> std::unique_ptr<Expression> override
358 {
359 // TODO: FIX WITH VISITOR?
361 std::unique_ptr<Expression> right = ((GetLeastSigOp().Copy())->Substitute(var, val));
362 DerivedT<Expression, Expression> comb = DerivedT<Expression, Expression> { *left, *right };
363
364 Oasis::SimplifyVisitor simplifyVisitor {};
365 auto simplified = comb.Accept(simplifyVisitor);
366 if (!simplified) {
367 return comb.Generalize();
368 }
369 return std::move(simplified.value());
370 }
375 auto SwapOperands() const -> DerivedT<LeastSigOpT, MostSigOpT>
376 {
377 return DerivedT { *this->leastSigOp, *this->mostSigOp };
378 }
379
380 auto operator=(const BinaryExpression& other) -> BinaryExpression& = default;
381
382 auto AcceptInternal(Visitor& visitor) const -> any override
383 {
384 const auto generalized = Generalize();
385 const auto& derivedGeneralized = dynamic_cast<const DerivedGeneralized&>(*generalized);
386 return visitor.Visit(derivedGeneralized);
387 }
388
391};
392
393} // Oasis
394
395#endif // OASIS_BINARYEXPRESSION_HPP
T begin(T... args)
A binary expression.
Definition BinaryExpression.hpp:83
auto Copy() const -> std::unique_ptr< Expression > final
Copies this expression.
Definition BinaryExpression.hpp:127
BinaryExpression(const BinaryExpression &other)
Definition BinaryExpression.hpp:90
std::unique_ptr< LeastSigOpT > leastSigOp
Definition BinaryExpression.hpp:390
auto AcceptInternal(Visitor &visitor) const -> any override
Dispatches the visitor on the generalized (Expression, Expression) form of this expression.
Definition BinaryExpression.hpp:382
auto Substitute(const Expression &var, const Expression &val) -> std::unique_ptr< Expression > override
Definition BinaryExpression.hpp:357
auto GetLeastSigOp() const -> const LeastSigOpT &
Gets the least significant operand of this expression.
Definition BinaryExpression.hpp:281
auto operator=(const BinaryExpression &other) -> BinaryExpression &=default
auto SetLeastSigOp(const T &op) -> bool
Sets the least significant operand of this expression.
Definition BinaryExpression.hpp:337
auto Integrate(const Expression &integrationVariable) const -> std::unique_ptr< Expression > override
Attempts to integrate this expression using integration rules.
Definition BinaryExpression.hpp:200
auto Generalize() const -> std::unique_ptr< Expression > final
Converts this expression to a more general expression.
Definition BinaryExpression.hpp:185
auto GetMostSigOp() const -> const MostSigOpT &
Gets the most significant operand of this expression.
Definition BinaryExpression.hpp:271
std::unique_ptr< MostSigOpT > mostSigOp
Definition BinaryExpression.hpp:389
auto SwapOperands() const -> DerivedT< LeastSigOpT, MostSigOpT >
Swaps the operands of this expression.
Definition BinaryExpression.hpp:375
auto SetMostSigOp(const T &op) -> bool
Sets the most significant operand of this expression.
Definition BinaryExpression.hpp:311
auto StructurallyEquivalent(const Expression &other) const -> bool final
Checks whether this expression is structurally equivalent to another expression.
Definition BinaryExpression.hpp:205
auto Flatten(std::vector< std::unique_ptr< Expression > > &out) const -> void
Flattens this expression.
Definition BinaryExpression.hpp:244
auto HasLeastSigOp() const -> bool
Gets whether this expression has a least significant operand.
Definition BinaryExpression.hpp:300
BinaryExpression(const Op1T &op1, const Op2T &op2, const OpsT &... ops)
Definition BinaryExpression.hpp:108
auto Equals(const Expression &other) const -> bool final
Compares this expression to another expression for equality.
Definition BinaryExpression.hpp:132
BinaryExpression(const MostSigOpT &mostSigOp, const LeastSigOpT &leastSigOp)
Definition BinaryExpression.hpp:101
auto HasMostSigOp() const -> bool
Gets whether this expression has a most significant operand.
Definition BinaryExpression.hpp:291
An expression.
Definition Expression.hpp:63
virtual auto Copy() const -> std::unique_ptr< Expression >=0
Copies this expression.
virtual auto GetCategory() const -> uint32_t
Gets the category of this expression.
Definition Expression.cpp:221
virtual auto GetType() const -> ExpressionType
Gets the type of this expression.
Definition Expression.cpp:236
Definition SimplifyVisitor.hpp:30
Definition Visit.hpp:13
Definition BinaryExpression.hpp:32
An expression concept.
Definition Concepts.hpp:30
A concept for an operand of a binary expression.
Definition BinaryExpression.hpp:29
Checks if type T is same as any of the provided types in U.
Definition Concepts.hpp:53
T emplace_back(T... args)
T end(T... args)
T erase(T... args)
T find_if(T... args)
T front(T... args)
T insert(T... args)
T is_same_v
Definition Add.hpp:11
auto BuildFromVector(const std::vector< std::unique_ptr< Expression > > &ops) -> std::unique_ptr< T< Expression, Expression > >
Builds a reasonably balanced binary expression from a vector of operands.
Definition BinaryExpression.hpp:42
boost::anys::unique_any any
Definition Expression.hpp:15
@ Commutative
Definition Expression.hpp:51
@ Associative
Definition Expression.hpp:50
T next(T... args)
T resize(T... args)
T transform(T... args)