OASIS
Open Algebra Software
Loading...
Searching...
No Matches
BoundedBinaryExpression.hpp
Go to the documentation of this file.
1//
2// Created by Matthew McCall on 5/2/24.
3//
4
5#ifndef OASIS_BOUNDEDBINARYEXPRESSION_HPP
6#define OASIS_BOUNDEDBINARYEXPRESSION_HPP
7
9#include "Expression.hpp"
10#include "RecursiveCast.hpp"
11#include "Serialization.hpp"
12
13namespace Oasis {
14template <template <IExpression, IExpression, IExpression, IExpression> class DerivedT, IExpression MostSigOpT = Expression, IExpression LeastSigOpT = MostSigOpT, IExpression LowerBoundT = Expression, IExpression UpperBoundT = LowerBoundT>
15class BoundedBinaryExpression : public BoundedExpression<LowerBoundT, UpperBoundT> {
16
17 using DerivedSpecialized = DerivedT<MostSigOpT, LeastSigOpT, LowerBoundT, UpperBoundT>;
18 using DerivedGeneralized = DerivedT<Expression, Expression, Expression, Expression>;
19
20public:
// Copy constructor (the Doxygen index lists `BoundedBinaryExpression(const BoundedBinaryExpression& other)`
// defined at original line 22).
// NOTE(review): this listing omits original lines 21-22 (the signature / initializer list) and
// lines 25 and 29 (the bodies of both if-statements below); presumably those bodies copy
// other's operands into this expression — confirm against the real header.
23 {
// Only a present most-significant operand is copied (body on missing line 25).
24 if (other.HasMostSigOp()) {
26 }
27
// Only a present least-significant operand is copied (body on missing line 29).
28 if (other.HasLeastSigOp()) {
30 }
31 }
32
33 BoundedBinaryExpression(const MostSigOpT& mostSigOp, const LeastSigOpT& leastSigOp)
34 {
35 SetMostSigOp(mostSigOp);
36 SetLeastSigOp(leastSigOp);
37 }
38
39 [[nodiscard]] auto Copy() const -> std::unique_ptr<Expression> final
40 {
41 return std::make_unique<DerivedSpecialized>(*static_cast<const DerivedSpecialized*>(this));
42 }
43
// Copies this expression, copying each present operand in its own task of `subflow`.
// NOTE(review): original line 62 is missing from this listing, so the return statement is not
// visible; presumably it wraps `copy` in a std::unique_ptr<Expression> — confirm against the header.
// NOTE(review): SetMostSigOp/SetLeastSigOp are called below with two arguments (a unique_ptr and a
// tf::Subflow), but only single-argument `SetXxx(const T& op)` templates appear in this listing —
// confirm that a matching overload exists elsewhere.
44 auto Copy(tf::Subflow& subflow) const -> std::unique_ptr<Expression> final
45 {
46 DerivedSpecialized copy;
47
// The tasks below capture `copy` (and `this`) by reference, so they must finish before this
// function returns; subflow.join() below is what guarantees that.
48 if (this->mostSigOp) {
49 subflow.emplace([this, &copy](tf::Subflow& sbf) {
50 copy.SetMostSigOp(mostSigOp->Copy(sbf), sbf);
51 });
52 }
53
54 if (this->leastSigOp) {
55 subflow.emplace([this, &copy](tf::Subflow& sbf) {
56 copy.SetLeastSigOp(leastSigOp->Copy(sbf), sbf);
57 });
58 }
59
// Wait for the operand-copy tasks before `copy` is read.
60 subflow.join();
61
63 }
64 [[nodiscard]] auto Differentiate(const Expression& differentiationVariable) const -> std::unique_ptr<Expression> override
65 {
66 return Generalize()->Differentiate(differentiationVariable);
67 }
68 [[nodiscard]] auto Equals(const Expression& other) const -> bool final
69 {
70 if (this->GetType() != other.GetType()) {
71 return false;
72 }
73
74 const auto otherGeneralized = other.Generalize();
75 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
76
77 bool mostSigOpMismatch = false, leastSigOpMismatch = false;
78
79 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
80 if (mostSigOp && otherBinaryGeneralized.HasMostSigOp()) {
81 mostSigOpMismatch = !mostSigOp->Equals(otherBinaryGeneralized.GetMostSigOp());
82 }
83 } else {
84 mostSigOpMismatch = true;
85 }
86
87 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
88 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
89 leastSigOpMismatch = !leastSigOp->Equals(otherBinaryGeneralized.GetLeastSigOp());
90 }
91 } else {
92 mostSigOpMismatch = true;
93 }
94
95 if (!mostSigOpMismatch && !leastSigOpMismatch) {
96 return true;
97 }
98
99 if (!(this->GetCategory() & Associative)) {
100 return false;
101 }
102
103 auto thisFlattened = std::vector<std::unique_ptr<Expression>> {};
104 auto otherFlattened = std::vector<std::unique_ptr<Expression>> {};
105
106 this->Flatten(thisFlattened);
107 otherBinaryGeneralized.Flatten(otherFlattened);
108
109 for (const auto& thisOperand : thisFlattened) {
110 if (std::find_if(otherFlattened.begin(), otherFlattened.end(), [&thisOperand](const auto& otherOperand) {
111 return thisOperand->Equals(*otherOperand);
112 })
113 == otherFlattened.end()) {
114 return false;
115 }
116 }
117
118 return true;
119 }
120
121 [[nodiscard]] auto Generalize() const -> std::unique_ptr<Expression> final
122 {
123 DerivedGeneralized generalized;
124
125 if (this->mostSigOp) {
126 generalized.SetMostSigOp(*this->mostSigOp->Copy());
127 }
128
129 if (this->leastSigOp) {
130 generalized.SetLeastSigOp(*this->leastSigOp->Copy());
131 }
132
133 return std::make_unique<DerivedGeneralized>(generalized);
134 }
135
136 auto Generalize(tf::Subflow& subflow) const -> std::unique_ptr<Expression> final
137 {
138 DerivedGeneralized generalized;
139
140 if (this->mostSigOp) {
141 subflow.emplace([this, &generalized](tf::Subflow& sbf) {
142 generalized.SetMostSigOp(*this->mostSigOp->Copy(sbf));
143 });
144 }
145
146 if (this->leastSigOp) {
147 subflow.emplace([this, &generalized](tf::Subflow& sbf) {
148 generalized.SetLeastSigOp(*this->leastSigOp->Copy(sbf));
149 });
150 }
151
152 subflow.join();
153
154 return std::make_unique<DerivedGeneralized>(generalized);
155 }
156
157 [[nodiscard]] auto Simplify() const -> std::unique_ptr<Expression> override
158 {
159 return Generalize()->Simplify();
160 }
161
162 auto Simplify(tf::Subflow& subflow) const -> std::unique_ptr<Expression> override
163 {
164 std::unique_ptr<Expression> generalized, simplified;
165
166 tf::Task generalizeTask = subflow.emplace([this, &generalized](tf::Subflow& sbf) {
167 generalized = Generalize(sbf);
168 });
169
170 tf::Task simplifyTask = subflow.emplace([&generalized, &simplified](tf::Subflow& sbf) {
171 simplified = generalized->Simplify(sbf);
172 });
173
174 simplifyTask.succeed(generalizeTask);
175 subflow.join();
176
177 return simplified;
178 }
179
180 [[nodiscard]] auto StructurallyEquivalent(const Expression& other) const -> bool final
181 {
182 if (this->GetType() != other.GetType()) {
183 return false;
184 }
185
186 const std::unique_ptr<Expression> otherGeneralized = other.Generalize();
187 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
188
189 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
190 if (this->HasMostSigOp() && otherBinaryGeneralized.HasMostSigOp()) {
191 if (!mostSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetMostSigOp())) {
192 return false;
193 }
194 }
195 }
196
197 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
198 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
199 if (!leastSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetLeastSigOp())) {
200 return false;
201 }
202 }
203 }
204
205 return true;
206 }
207
208 auto StructurallyEquivalent(const Expression& other, tf::Subflow& subflow) const -> bool final
209 {
210 if (this->GetType() != other.GetType()) {
211 return false;
212 }
213
214 std::unique_ptr<Expression> otherGeneralized;
215
216 tf::Task generalizeTask = subflow.emplace([&](tf::Subflow& sbf) {
217 otherGeneralized = other.Generalize(sbf);
218 });
219
220 bool mostSigOpEquivalent = false, leastSigOpEquivalent = false;
221
222 if (this->mostSigOp) {
223 tf::Task compMostSigOp = subflow.emplace([this, &otherGeneralized, &mostSigOpEquivalent](tf::Subflow& sbf) {
224 if (const auto& otherBinary = static_cast<const DerivedGeneralized&>(*otherGeneralized); otherBinary.HasMostSigOp()) {
225 mostSigOpEquivalent = mostSigOp->StructurallyEquivalent(otherBinary.GetMostSigOp(), sbf);
226 }
227 });
228
229 compMostSigOp.succeed(generalizeTask);
230 }
231
232 if (this->leastSigOp) {
233 tf::Task compLeastSigOp = subflow.emplace([this, &otherGeneralized, &leastSigOpEquivalent](tf::Subflow& sbf) {
234 if (const auto& otherBinary = static_cast<const DerivedGeneralized&>(*otherGeneralized); otherBinary.HasLeastSigOp()) {
235 leastSigOpEquivalent = leastSigOp->StructurallyEquivalent(otherBinary.GetLeastSigOp(), sbf);
236 }
237 });
238
239 compLeastSigOp.succeed(generalizeTask);
240 }
241
242 subflow.join();
243
244 return mostSigOpEquivalent && leastSigOpEquivalent;
245 }
246
// Flattens this expression into `out` (the Doxygen index lists
// `auto Flatten(std::vector<std::unique_ptr<Expression>>& out) const -> void`).
// For each operand: if Is<DerivedT>() reports it as the same operation, it is generalized and
// flattened recursively, so nested chains contribute their leaves. Otherwise a copy of the
// operand is appended to `out`. Most-significant operands are visited before least-significant ones.
// NOTE(review): the signature (original line 258) and its documentation are missing from this listing.
// NOTE(review): both operands are dereferenced without a null check — callers must ensure
// HasMostSigOp() and HasLeastSigOp() are both true.
259 {
260 if (this->mostSigOp->template Is<DerivedT>()) {
261 auto generalizedMostSigOp = this->mostSigOp->Generalize();
// This local shadows the member `mostSigOp` for the rest of this branch.
262 const auto& mostSigOp = static_cast<const DerivedGeneralized&>(*generalizedMostSigOp);
263 mostSigOp.Flatten(out);
264 } else {
265 out.push_back(this->mostSigOp->Copy());
266 }
267
268 if (this->leastSigOp->template Is<DerivedT>()) {
269 auto generalizedLeastSigOp = this->leastSigOp->Generalize();
// This local shadows the member `leastSigOp` for the rest of this branch.
270 const auto& leastSigOp = static_cast<const DerivedGeneralized&>(*generalizedLeastSigOp);
271 leastSigOp.Flatten(out);
272 } else {
273 out.push_back(this->leastSigOp->Copy());
274 }
275 }
276
281 auto GetMostSigOp() const -> const MostSigOpT&
282 {
283 assert(mostSigOp != nullptr);
284 return *mostSigOp;
285 }
286
291 auto GetLeastSigOp() const -> const LeastSigOpT&
292 {
293 assert(leastSigOp != nullptr);
294 return *leastSigOp;
295 }
296
301 [[nodiscard]] auto HasMostSigOp() const -> bool
302 {
303 return mostSigOp != nullptr;
304 }
305
310 [[nodiscard]] auto HasLeastSigOp() const -> bool
311 {
312 return leastSigOp != nullptr;
313 }
314
// Sets the most significant operand of this expression.
// Returns true once an operand has been stored. Returns false only when the final fallback,
// Oasis::RecursiveCast<MostSigOpT>(op), yields nothing.
// NOTE(review): original lines 320, 323 and 328 are missing from this listing. They presumably
// hold a requires-clause and the conditions of the first two branches below — confirm against
// the real header.
319 template <typename T>
321 auto SetMostSigOp(const T& op) -> bool
322 {
// First branch (condition on missing line 323): store a polymorphic copy of `op`.
324 this->mostSigOp = op.Copy();
325 return true;
326 }
327
// Second branch (condition on missing line 328): copy-construct a MostSigOpT from `op`.
329 this->mostSigOp = std::make_unique<MostSigOpT>(op);
330 return true;
331 }
332
// Fallback: try to convert `op` into a MostSigOpT.
333 if (auto castedOp = Oasis::RecursiveCast<MostSigOpT>(op); castedOp) {
334 mostSigOp = std::move(castedOp);
335 return true;
336 }
337
338 return false;
339 }
340
// Sets the least significant operand of this expression.
// Returns true once an operand has been stored. Returns false only when the final fallback,
// Oasis::RecursiveCast<LeastSigOpT>(op), yields nothing.
// NOTE(review): original lines 346, 349 and 354 are missing from this listing. They presumably
// hold a requires-clause and the conditions of the first two branches below — confirm against
// the real header.
345 template <typename T>
347 auto SetLeastSigOp(const T& op) -> bool
348 {
// First branch (condition on missing line 349): store a polymorphic copy of `op`.
350 this->leastSigOp = op.Copy();
351 return true;
352 }
353
// Second branch (condition on missing line 354): copy-construct a LeastSigOpT from `op`.
355 this->leastSigOp = std::make_unique<LeastSigOpT>(op);
356 return true;
357 }
358
// Fallback: try to convert `op` into a LeastSigOpT.
359 if (auto castedOp = Oasis::RecursiveCast<LeastSigOpT>(op); castedOp) {
360 leastSigOp = std::move(castedOp);
361 return true;
362 }
363
364 return false;
365 }
366
367 auto Substitute(const Expression& var, const Expression& val) -> std::unique_ptr<Expression> override
368 {
369 const std::unique_ptr<Expression> left = GetMostSigOp().Substitute(var, val);
370 const std::unique_ptr<Expression> right = GetLeastSigOp().Substitute(var, val);
371 DerivedGeneralized comb { *left, *right };
372 auto ret = comb.Simplify();
373 return ret;
374 }
379 auto SwapOperands() const -> DerivedSpecialized
380 {
381 return DerivedT { *this->leastSigOp, *this->mostSigOp };
382 }
383
385
386 void Serialize(SerializationVisitor& visitor) const override
387 {
388 const auto generalized = Generalize();
389 const auto& derivedGeneralized = dynamic_cast<const DerivedGeneralized&>(*generalized);
390 visitor.Serialize(derivedGeneralized);
391 }
392
393private:
396};
397
398}
399
400#endif // OASIS_BOUNDEDBINARYEXPRESSION_HPP
Definition BoundedBinaryExpression.hpp:15
auto Differentiate(const Expression &differentiationVariable) const -> std::unique_ptr< Expression > override
Tries to differentiate this function.
Definition BoundedBinaryExpression.hpp:64
auto SwapOperands() const -> DerivedSpecialized
Swaps the operands of this expression.
Definition BoundedBinaryExpression.hpp:379
auto GetMostSigOp() const -> const MostSigOpT &
Gets the most significant operand of this expression.
Definition BoundedBinaryExpression.hpp:281
auto Flatten(std::vector< std::unique_ptr< Expression > > &out) const -> void
Flattens this expression.
Definition BoundedBinaryExpression.hpp:258
auto Simplify(tf::Subflow &subflow) const -> std::unique_ptr< Expression > override
Definition BoundedBinaryExpression.hpp:162
BoundedBinaryExpression(const MostSigOpT &mostSigOp, const LeastSigOpT &leastSigOp)
Definition BoundedBinaryExpression.hpp:33
auto Simplify() const -> std::unique_ptr< Expression > override
Simplifies this expression.
Definition BoundedBinaryExpression.hpp:157
auto Substitute(const Expression &var, const Expression &val) -> std::unique_ptr< Expression > override
Definition BoundedBinaryExpression.hpp:367
auto Equals(const Expression &other) const -> bool final
Compares this expression to another expression for equality.
Definition BoundedBinaryExpression.hpp:68
auto SetMostSigOp(const T &op) -> bool
Sets the most significant operand of this expression.
Definition BoundedBinaryExpression.hpp:321
auto Generalize() const -> std::unique_ptr< Expression > final
Converts this expression to a more general expression.
Definition BoundedBinaryExpression.hpp:121
auto HasMostSigOp() const -> bool
Gets whether this expression has a most significant operand.
Definition BoundedBinaryExpression.hpp:301
void Serialize(SerializationVisitor &visitor) const override
This function serializes the expression object.
Definition BoundedBinaryExpression.hpp:386
auto GetLeastSigOp() const -> const LeastSigOpT &
Gets the least significant operand of this expression.
Definition BoundedBinaryExpression.hpp:291
auto HasLeastSigOp() const -> bool
Gets whether this expression has a least significant operand.
Definition BoundedBinaryExpression.hpp:310
auto SetLeastSigOp(const T &op) -> bool
Sets the least significant operand of this expression.
Definition BoundedBinaryExpression.hpp:347
auto Generalize(tf::Subflow &subflow) const -> std::unique_ptr< Expression > final
Definition BoundedBinaryExpression.hpp:136
auto Copy(tf::Subflow &subflow) const -> std::unique_ptr< Expression > final
Definition BoundedBinaryExpression.hpp:44
auto Copy() const -> std::unique_ptr< Expression > final
Copies this expression.
Definition BoundedBinaryExpression.hpp:39
auto StructurallyEquivalent(const Expression &other) const -> bool final
Checks whether this expression is structurally equivalent to another expression.
Definition BoundedBinaryExpression.hpp:180
auto operator=(const BoundedBinaryExpression &other) -> BoundedBinaryExpression &=default
auto StructurallyEquivalent(const Expression &other, tf::Subflow &subflow) const -> bool final
Definition BoundedBinaryExpression.hpp:208
BoundedBinaryExpression(const BoundedBinaryExpression &other)
Definition BoundedBinaryExpression.hpp:22
Definition BoundedExpression.hpp:15
An expression.
Definition Expression.hpp:56
virtual auto GetCategory() const -> uint32_t
Gets the category of this expression.
Definition Expression.cpp:212
virtual auto GetType() const -> ExpressionType
Gets the type of this expression.
Definition Expression.cpp:220
virtual auto Simplify() const -> std::unique_ptr< Expression >
Simplifies this expression.
Definition Expression.cpp:244
Definition Serialization.hpp:50
virtual void Serialize(const Real &real)=0
Checks whether type T is the same as any of the provided types in U.
Definition Concepts.hpp:48
T find_if(T... args)
T is_same_v
Definition Add.hpp:11
@ Associative
Definition Expression.hpp:43