OASIS
Open Algebra Software
Loading...
Searching...
No Matches
BinaryExpression.hpp
Go to the documentation of this file.
1//
2// Created by Matthew McCall on 7/2/23.
3//
4
5#ifndef OASIS_BINARYEXPRESSION_HPP
6#define OASIS_BINARYEXPRESSION_HPP
7
8#include <algorithm>
9#include <cassert>
10#include <functional>
11#include <list>
12
13#include "Expression.hpp"
15#include "RecursiveCast.hpp"
16#include "Visit.hpp"
17
18namespace Oasis {
28template <typename MostSigOpT, typename LeastSigOpT, typename T>
30
31template <template <typename, typename> typename T>
32concept IAssociativeAndCommutative = IExpression<T<Expression, Expression>> && ((T<Expression, Expression>::GetStaticCategory() & (Associative | Commutative)) == (Associative | Commutative));
33
40template <template <typename, typename> typename T>
43{
44 if (ops.size() <= 1) {
45 return nullptr;
46 }
47
48 using GeneralizedT = T<Expression, Expression>;
49
51 opsList.resize(ops.size());
52
53 std::transform(ops.begin(), ops.end(), opsList.begin(), [](const auto& op) { return op->Copy(); });
54
55 while (std::next(opsList.begin()) != opsList.end()) {
56 for (auto i = opsList.begin(); i != opsList.end() && std::next(i) != opsList.end();) {
57 auto node = std::make_unique<GeneralizedT>(**i, **std::next(i));
58 opsList.insert(i, std::move(node));
59 i = opsList.erase(i, std::next(i, 2));
60 }
61 }
62
63 auto* result = dynamic_cast<GeneralizedT*>(opsList.front().release());
64 return std::unique_ptr<GeneralizedT>(result);
65}
66
82template <template <IExpression, IExpression> class DerivedT, IExpression MostSigOpT = Expression, IExpression LeastSigOpT = MostSigOpT>
84
85 using DerivedSpecialized = DerivedT<MostSigOpT, LeastSigOpT>;
86 using DerivedGeneralized = DerivedT<Expression, Expression>;
87
88public:
89 BinaryExpression() = default;
91 {
92 if (other.HasMostSigOp()) {
94 }
95
96 if (other.HasLeastSigOp()) {
98 }
99 }
100
101 BinaryExpression(const MostSigOpT& mostSigOp, const LeastSigOpT& leastSigOp)
102 {
105 }
106
107 template <IExpression Op1T, IExpression Op2T, IExpression... OpsT>
108 BinaryExpression(const Op1T& op1, const Op2T& op2, const OpsT&... ops)
109 {
110 static_assert(IAssociativeAndCommutative<DerivedT>, "List initializer only supported for associative and commutative expressions");
111 static_assert(std::is_same_v<DerivedGeneralized, DerivedSpecialized>, "List initializer only supported for generalized expressions");
112
114
115 for (auto opWrapper : std::vector<std::reference_wrapper<const Expression>> { static_cast<const Expression&>(op1), static_cast<const Expression&>(op2), (static_cast<const Expression&>(ops))... }) {
116 const Expression& operand = opWrapper.get();
117 opsVec.emplace_back(operand.Copy());
118 }
119
120 // build expression from vector
121 auto generalized = BuildFromVector<DerivedT>(opsVec);
122
123 SetLeastSigOp(generalized->GetLeastSigOp());
124 SetMostSigOp(generalized->GetMostSigOp());
125 }
126
127 [[nodiscard]] auto Copy() const -> std::unique_ptr<Expression> final
128 {
129 return std::make_unique<DerivedSpecialized>(*static_cast<const DerivedSpecialized*>(this));
130 }
131
132 [[nodiscard]] auto Differentiate(const Expression& differentiationVariable) const -> std::unique_ptr<Expression> override
133 {
134 return Generalize()->Differentiate(differentiationVariable);
135 }
136 [[nodiscard]] auto Equals(const Expression& other) const -> bool final
137 {
138 if (this->GetType() != other.GetType()) {
139 return false;
140 }
141
142 const auto otherGeneralized = other.Generalize();
143 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
144
145 bool mostSigOpMismatch = false, leastSigOpMismatch = false;
146
147 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
148 if (mostSigOp && otherBinaryGeneralized.HasMostSigOp()) {
149 mostSigOpMismatch = !mostSigOp->Equals(otherBinaryGeneralized.GetMostSigOp());
150 }
151 } else {
152 mostSigOpMismatch = true;
153 }
154
155 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
156 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
157 leastSigOpMismatch = !leastSigOp->Equals(otherBinaryGeneralized.GetLeastSigOp());
158 }
159 } else {
160 mostSigOpMismatch = true;
161 }
162
163 if (!mostSigOpMismatch && !leastSigOpMismatch) {
164 return true;
165 }
166
167 if (!(this->GetCategory() & Associative)) {
168 return false;
169 }
170
171 auto thisFlattened = std::vector<std::unique_ptr<Expression>> {};
172 auto otherFlattened = std::vector<std::unique_ptr<Expression>> {};
173
174 this->Flatten(thisFlattened);
175 otherBinaryGeneralized.Flatten(otherFlattened);
176
177 for (const auto& thisOperand : thisFlattened) {
178 if (std::find_if(otherFlattened.begin(), otherFlattened.end(), [&thisOperand](const auto& otherOperand) {
179 return thisOperand->Equals(*otherOperand);
180 })
181 == otherFlattened.end()) {
182 return false;
183 }
184 }
185
186 return true;
187 }
188
189 [[nodiscard]] auto Generalize() const -> std::unique_ptr<Expression> final
190 {
191 DerivedGeneralized generalized;
192
193 if (this->mostSigOp) {
194 generalized.SetMostSigOp(*this->mostSigOp->Copy());
195 }
196
197 if (this->leastSigOp) {
198 generalized.SetLeastSigOp(*this->leastSigOp->Copy());
199 }
200
201 return std::make_unique<DerivedGeneralized>(generalized);
202 }
203
204 [[nodiscard]] auto Integrate(const Expression& integrationVariable) const -> std::unique_ptr<Expression> override
205 {
206 return Generalize()->Integrate(integrationVariable);
207 }
208
209 [[nodiscard]] auto StructurallyEquivalent(const Expression& other) const -> bool final
210 {
211 if (this->GetType() != other.GetType()) {
212 return false;
213 }
214
215 const std::unique_ptr<Expression> otherGeneralized = other.Generalize();
216 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
217
218 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
219 if (this->HasMostSigOp() && otherBinaryGeneralized.HasMostSigOp()) {
220 if (!mostSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetMostSigOp())) {
221 return false;
222 }
223 }
224 }
225
226 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
227 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
228 if (!leastSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetLeastSigOp())) {
229 return false;
230 }
231 }
232 }
233
234 return true;
235 }
236
249 {
250 if (mostSigOp) {
251 if (this->mostSigOp->template Is<DerivedGeneralized>()) {
252 auto generalizedMostSigOp = this->mostSigOp->Generalize();
253 const auto& mostSigOp = static_cast<const DerivedGeneralized&>(*generalizedMostSigOp);
254 mostSigOp.Flatten(out);
255 } else {
256 out.push_back(this->mostSigOp->Copy());
257 }
258 }
259
260 if (leastSigOp) {
261 if (this->leastSigOp->template Is<DerivedGeneralized>()) {
262 auto generalizedLeastSigOp = this->leastSigOp->Generalize();
263 const auto& leastSigOp = static_cast<const DerivedGeneralized&>(*generalizedLeastSigOp);
264 leastSigOp.Flatten(out);
265 } else {
266 out.push_back(this->leastSigOp->Copy());
267 }
268 }
269 }
270
275 auto GetMostSigOp() const -> const MostSigOpT&
276 {
277 assert(mostSigOp != nullptr);
278 return *mostSigOp;
279 }
280
285 auto GetLeastSigOp() const -> const LeastSigOpT&
286 {
287 assert(leastSigOp != nullptr);
288 return *leastSigOp;
289 }
290
295 [[nodiscard]] auto HasMostSigOp() const -> bool
296 {
297 return mostSigOp != nullptr;
298 }
299
304 [[nodiscard]] auto HasLeastSigOp() const -> bool
305 {
306 return leastSigOp != nullptr;
307 }
308
313 template <typename T>
315 auto SetMostSigOp(const T& op) -> bool
316 {
318 this->mostSigOp = op.Copy();
319 return true;
320 }
321
324 return true;
325 }
326
327 if (auto castedOp = Oasis::RecursiveCast<MostSigOpT>(op); castedOp) {
328 mostSigOp = std::move(castedOp);
329 return true;
330 }
331
332 return false;
333 }
334
339 template <typename T>
341 auto SetLeastSigOp(const T& op) -> bool
342 {
344 this->leastSigOp = op.Copy();
345 return true;
346 }
347
350 return true;
351 }
352
353 if (auto castedOp = Oasis::RecursiveCast<LeastSigOpT>(op); castedOp) {
354 leastSigOp = std::move(castedOp);
355 return true;
356 }
357
358 return false;
359 }
360
361 auto Substitute(const Expression& var, const Expression& val) -> std::unique_ptr<Expression> override
362 {
363 // TODO: FIX WITH VISITOR?
365 std::unique_ptr<Expression> right = ((GetLeastSigOp().Copy())->Substitute(var, val));
366 DerivedT<Expression, Expression> comb = DerivedT<Expression, Expression> { *left, *right };
367
368 Oasis::SimplifyVisitor simplifyVisitor {};
369 auto simplified = comb.Accept(simplifyVisitor);
370 if (!simplified) {
371 return comb.Generalize();
372 }
373 return std::move(simplified.value());
374 }
379 auto SwapOperands() const -> DerivedT<LeastSigOpT, MostSigOpT>
380 {
381 return DerivedT { *this->leastSigOp, *this->mostSigOp };
382 }
383
384 auto operator=(const BinaryExpression& other) -> BinaryExpression& = default;
385
386 auto AcceptInternal(Visitor& visitor) const -> any override
387 {
388 const auto generalized = Generalize();
389 const auto& derivedGeneralized = dynamic_cast<const DerivedGeneralized&>(*generalized);
390 return visitor.Visit(derivedGeneralized);
391 }
392
395};
396
397} // Oasis
398
399#endif // OASIS_BINARYEXPRESSION_HPP
T begin(T... args)
A binary expression.
Definition BinaryExpression.hpp:83
auto Copy() const -> std::unique_ptr< Expression > final
Copies this expression.
Definition BinaryExpression.hpp:127
BinaryExpression(const BinaryExpression &other)
Definition BinaryExpression.hpp:90
std::unique_ptr< LeastSigOpT > leastSigOp
Definition BinaryExpression.hpp:394
auto Differentiate(const Expression &differentiationVariable) const -> std::unique_ptr< Expression > override
Tries to differentiate this function.
Definition BinaryExpression.hpp:132
auto AcceptInternal(Visitor &visitor) const -> any override
Dispatches the visitor on the generalized form of this expression.
Definition BinaryExpression.hpp:386
auto Substitute(const Expression &var, const Expression &val) -> std::unique_ptr< Expression > override
Definition BinaryExpression.hpp:361
auto GetLeastSigOp() const -> const LeastSigOpT &
Gets the least significant operand of this expression.
Definition BinaryExpression.hpp:285
auto operator=(const BinaryExpression &other) -> BinaryExpression &=default
auto SetLeastSigOp(const T &op) -> bool
Sets the least significant operand of this expression.
Definition BinaryExpression.hpp:341
auto Integrate(const Expression &integrationVariable) const -> std::unique_ptr< Expression > override
Attempts to integrate this expression using integration rules.
Definition BinaryExpression.hpp:204
auto Generalize() const -> std::unique_ptr< Expression > final
Converts this expression to a more general expression.
Definition BinaryExpression.hpp:189
auto GetMostSigOp() const -> const MostSigOpT &
Gets the most significant operand of this expression.
Definition BinaryExpression.hpp:275
std::unique_ptr< MostSigOpT > mostSigOp
Definition BinaryExpression.hpp:393
auto SwapOperands() const -> DerivedT< LeastSigOpT, MostSigOpT >
Swaps the operands of this expression.
Definition BinaryExpression.hpp:379
auto SetMostSigOp(const T &op) -> bool
Sets the most significant operand of this expression.
Definition BinaryExpression.hpp:315
auto StructurallyEquivalent(const Expression &other) const -> bool final
Checks whether this expression is structurally equivalent to another expression.
Definition BinaryExpression.hpp:209
auto Flatten(std::vector< std::unique_ptr< Expression > > &out) const -> void
Flattens this expression.
Definition BinaryExpression.hpp:248
auto HasLeastSigOp() const -> bool
Gets whether this expression has a least significant operand.
Definition BinaryExpression.hpp:304
BinaryExpression(const Op1T &op1, const Op2T &op2, const OpsT &... ops)
Definition BinaryExpression.hpp:108
auto Equals(const Expression &other) const -> bool final
Compares this expression to another expression for equality.
Definition BinaryExpression.hpp:136
BinaryExpression(const MostSigOpT &mostSigOp, const LeastSigOpT &leastSigOp)
Definition BinaryExpression.hpp:101
auto HasMostSigOp() const -> bool
Gets whether this expression has a most significant operand.
Definition BinaryExpression.hpp:295
An expression.
Definition Expression.hpp:63
virtual auto Copy() const -> std::unique_ptr< Expression >=0
Copies this expression.
virtual auto GetCategory() const -> uint32_t
Gets the category of this expression.
Definition Expression.cpp:219
virtual auto GetType() const -> ExpressionType
Gets the type of this expression.
Definition Expression.cpp:229
Definition SimplifyVisitor.hpp:25
Definition Visit.hpp:13
Definition BinaryExpression.hpp:32
An expression concept.
Definition Concepts.hpp:30
A concept for an operand of a binary expression.
Definition BinaryExpression.hpp:29
Checks if type T is same as any of the provided types in U.
Definition Concepts.hpp:53
T emplace_back(T... args)
T end(T... args)
T erase(T... args)
T find_if(T... args)
T front(T... args)
T insert(T... args)
T is_same_v
Definition Add.hpp:11
auto BuildFromVector(const std::vector< std::unique_ptr< Expression > > &ops) -> std::unique_ptr< T< Expression, Expression > >
Builds a reasonably balanced binary expression from a vector of operands.
Definition BinaryExpression.hpp:42
boost::anys::unique_any any
Definition Expression.hpp:15
@ Commutative
Definition Expression.hpp:51
@ Associative
Definition Expression.hpp:50
T next(T... args)
T resize(T... args)
T transform(T... args)