00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014 #include "base/stl_util.h"
00015 #include "constraint_solver/constraint_solveri.h"
00016 #include "examples/global_arith.h"
00017
00018 namespace operations_research {
00019
00020 class ArithmeticPropagator;
00021
00022
00023
00024 class SubstitutionMap {
00025 public:
00026 struct Offset {
00027 Offset() : var_index(-1), offset(0) {}
00028 Offset(int v, int64 o) : var_index(v), offset(o) {}
00029 int var_index;
00030 int64 offset;
00031 };
00032
00033 void AddSubstitution(int left_var, int right_var, int64 right_offset) {
00034
00035 substitutions_[left_var] = Offset(right_var, right_offset);
00036 }
00037
00038 void ProcessAllSubstitutions(Callback3<int, int, int64>* const hook) {
00039 for (hash_map<int, Offset>::const_iterator it = substitutions_.begin();
00040 it != substitutions_.end();
00041 ++it) {
00042 hook->Run(it->first, it->second.var_index, it->second.offset);
00043 }
00044 }
00045 private:
00046 hash_map<int, Offset> substitutions_;
00047 };
00048
00049
00050
00051 struct Bounds {
00052 Bounds() : lb(kint64min), ub(kint64max) {}
00053 Bounds(int64 l, int64 u) : lb(l), ub(u) {}
00054
00055 void Intersect(int64 new_lb, int64 new_ub) {
00056 lb = std::max(lb, new_lb);
00057 ub = std::min(ub, new_ub);
00058 }
00059
00060 void Intersect(const Bounds& other) {
00061 Intersect(other.lb, other.ub);
00062 }
00063
00064 void Union(int64 new_lb, int64 new_ub) {
00065 lb = std::min(lb, new_lb);
00066 ub = std::max(ub, new_ub);
00067 }
00068
00069 void Union(const Bounds& other) {
00070 Union(other.lb, other.ub);
00071 }
00072
00073 bool IsEqual(const Bounds& other) {
00074 return (ub == other.ub && lb == other.lb);
00075 }
00076
00077 bool IsIncluded(const Bounds& other) {
00078 return (ub <= other.ub && lb >= other.lb);
00079 }
00080
00081 int64 lb;
00082 int64 ub;
00083 };
00084
00085
00086
00087 class BoundsStore {
00088 public:
00089 BoundsStore(vector<Bounds>* initial_bounds)
00090 : initial_bounds_(initial_bounds) {}
00091
00092 void SetRange(int var_index, int64 lb, int64 ub) {
00093 hash_map<int, Bounds>::iterator it = modified_bounds_.find(var_index);
00094 if (it == modified_bounds_.end()) {
00095 Bounds new_bounds(lb, ub);
00096 const Bounds& initial = (*initial_bounds_)[var_index];
00097 new_bounds.Intersect(initial);
00098 if (!new_bounds.IsEqual(initial)) {
00099 modified_bounds_.insert(make_pair(var_index, new_bounds));
00100 }
00101 } else {
00102 it->second.Intersect(lb, ub);
00103 }
00104 }
00105
00106 void Clear() {
00107 modified_bounds_.clear();
00108 }
00109
00110 const hash_map<int, Bounds>& modified_bounds() const {
00111 return modified_bounds_;
00112 }
00113
00114 vector<Bounds>* initial_bounds() const { return initial_bounds_; }
00115
00116 void Apply() {
00117 for (hash_map<int, Bounds>::const_iterator it = modified_bounds_.begin();
00118 it != modified_bounds_.end();
00119 ++it) {
00120 (*initial_bounds_)[it->first] = it->second;
00121 }
00122 }
00123
00124 private:
00125 vector<Bounds>* initial_bounds_;
00126 hash_map<int, Bounds> modified_bounds_;
00127 };
00128
00129
00130
00131 class ArithmeticConstraint {
00132 public:
00133 virtual ~ArithmeticConstraint() {}
00134
00135 const vector<int>& vars() const { return vars_; }
00136
00137 virtual bool Propagate(BoundsStore* const store) = 0;
00138 virtual void Replace(int to_replace, int var, int64 offset) = 0;
00139 virtual bool Deduce(ArithmeticPropagator* const propagator) const = 0;
00140 virtual string DebugString() const = 0;
00141 private:
00142 const vector<int> vars_;
00143 };
00144
00145
00146
00147 class ArithmeticPropagator : PropagationBaseObject {
00148 public:
00149 ArithmeticPropagator(Solver* const solver, Demon* const demon)
00150 : PropagationBaseObject(solver), demon_(demon) {}
00151
00152 void ReduceProblem() {
00153 for (int constraint_index = 0;
00154 constraint_index < constraints_.size();
00155 ++constraint_index) {
00156 if (constraints_[constraint_index]->Deduce(this)) {
00157 protected_constraints_.insert(constraint_index);
00158 }
00159 }
00160 scoped_ptr<Callback3<int, int, int64> > hook(
00161 NewPermanentCallback(this,
00162 &ArithmeticPropagator::ProcessOneSubstitution));
00163 substitution_map_.ProcessAllSubstitutions(hook.get());
00164 }
00165
00166 void Post() {
00167 for (int constraint_index = 0;
00168 constraint_index < constraints_.size();
00169 ++constraint_index) {
00170 const vector<int>& vars = constraints_[constraint_index]->vars();
00171 for (int var_index = 0; var_index < vars.size(); ++var_index) {
00172 dependencies_[vars[var_index]].push_back(constraint_index);
00173 }
00174 }
00175 }
00176
00177 void InitialPropagate() {
00178
00179 }
00180
00181 void Update(int var_index) {
00182 Enqueue(demon_);
00183 }
00184
00185 void AddConstraint(ArithmeticConstraint* const ct) {
00186 constraints_.push_back(ct);
00187 }
00188
00189 void AddVariable(int64 lb, int64 ub) {
00190 bounds_.push_back(Bounds(lb, ub));
00191 }
00192
00193 const vector<IntVar*> vars() const { return vars_; }
00194
00195 int VarIndex(IntVar* const var) {
00196 hash_map<IntVar*, int>::const_iterator it = var_map_.find(var);
00197 if (it == var_map_.end()) {
00198 const int index = var_map_.size();
00199 var_map_[var] = index;
00200 return index;
00201 } else {
00202 return it->second;
00203 }
00204 }
00205
00206 void AddSubstitution(int left_var, int right_var, int64 right_offset) {
00207 substitution_map_.AddSubstitution(left_var, right_var, right_offset);
00208 }
00209
00210 void AddNewBounds(int var_index, int64 lb, int64 ub) {
00211 bounds_[var_index].Intersect(lb, ub);
00212 }
00213
00214 void ProcessOneSubstitution(int left_var, int right_var, int64 right_offset) {
00215 for (int constraint_index = 0;
00216 constraint_index < constraints_.size();
00217 ++constraint_index) {
00218 if (!ContainsKey(protected_constraints_, constraint_index)) {
00219 ArithmeticConstraint* const constraint = constraints_[constraint_index];
00220 constraint->Replace(left_var, right_var, right_offset);
00221 }
00222 }
00223 }
00224
00225 void PrintModel() {
00226 LOG(INFO) << "Vars:";
00227 for (int i = 0; i < bounds_.size(); ++i) {
00228 LOG(INFO) << " var<" << i << "> = [" << bounds_[i].lb
00229 << " .. " << bounds_[i].ub << "]";
00230 }
00231 LOG(INFO) << "Constraints";
00232 for (int i = 0; i < constraints_.size(); ++i) {
00233 LOG(INFO) << " " << constraints_[i]->DebugString();
00234 }
00235 }
00236 private:
00237 Demon* const demon_;
00238 vector<IntVar*> vars_;
00239 hash_map<IntVar*, int> var_map_;
00240 vector<ArithmeticConstraint*> constraints_;
00241 vector<Bounds> bounds_;
00242 vector<vector<int> > dependencies_;
00243 SubstitutionMap substitution_map_;
00244 hash_set<int> protected_constraints_;
00245 };
00246
00247
00248
00249 class RowConstraint : public ArithmeticConstraint {
00250 public:
00251 RowConstraint(int64 lb, int64 ub) : lb_(lb), ub_(ub) {}
00252 virtual ~RowConstraint() {}
00253
00254 void AddTerm(int var_index, int64 coefficient) {
00255
00256 coefficients_[var_index] = coefficient;
00257 }
00258
00259 virtual bool Propagate(BoundsStore* const store) {
00260 return true;
00261 }
00262
00263 virtual void Replace(int to_replace, int var, int64 offset) {
00264 hash_map<int, int64>::iterator find_other = coefficients_.find(to_replace);
00265 if (find_other != coefficients_.end()) {
00266 hash_map<int, int64>::iterator find_var = coefficients_.find(var);
00267 const int64 other_coefficient = find_other->second;
00268 if (lb_ != kint64min) {
00269 lb_ += other_coefficient * offset;
00270 }
00271 if (ub_ != kint64max) {
00272 ub_ += other_coefficient * offset;
00273 }
00274 coefficients_.erase(find_other);
00275 if (find_var == coefficients_.end()) {
00276 coefficients_[var] = other_coefficient;
00277 } else {
00278 find_var->second += other_coefficient;
00279 if (find_var->second == 0) {
00280 coefficients_.erase(find_var);
00281 }
00282 }
00283 }
00284 }
00285
00286 virtual bool Deduce(ArithmeticPropagator* const propagator) const {
00287
00288 if (lb_ == ub_ && coefficients_.size() == 2) {
00289 hash_map<int, int64>::const_iterator it = coefficients_.begin();
00290 const int var1 = it->first;
00291 const int64 coeff1 = it->second;
00292 ++it;
00293 const int var2 = it->first;
00294 const int64 coeff2 = it->second;
00295 ++it;
00296 CHECK(it == coefficients_.end());
00297 if (coeff1 == 1 && coeff2 == -1) {
00298 propagator->AddSubstitution(var1, var2, lb_);
00299 return true;
00300 } else if (coeff1 == -1 && coeff2 && 1) {
00301 propagator->AddSubstitution(var2, var1, lb_);
00302 return true;
00303 }
00304 }
00305 return false;
00306 }
00307
00308 virtual string DebugString() const {
00309 string output = "(";
00310 bool first = true;
00311 for (hash_map<int, int64>::const_iterator it = coefficients_.begin();
00312 it != coefficients_.end();
00313 ++it) {
00314 if (it->second != 0) {
00315 if (first) {
00316 first = false;
00317 if (it->second == 1) {
00318 output += StringPrintf("var<%d>", it->first);
00319 } else if (it->second == -1) {
00320 output += StringPrintf("-var<%d>", it->first);
00321 } else {
00322 output += StringPrintf("%lld*var<%d>", it->second, it->first);
00323 }
00324 } else if (it->second == 1) {
00325 output += StringPrintf(" + var<%d>", it->first);
00326 } else if (it->second == -1) {
00327 output += StringPrintf(" - var<%d>", it->first);
00328 } else if (it->second > 0) {
00329 output += StringPrintf(" + %lld*var<%d>", it->second, it->first);
00330 } else {
00331 output += StringPrintf(" - %lld*var<%d>", -it->second, it->first);
00332 }
00333 }
00334 }
00335 if (lb_ == ub_) {
00336 output += StringPrintf(" == %lld)", ub_);
00337 } else if (lb_ == kint64min) {
00338 output += StringPrintf(" <= %lld)", ub_);
00339 } else if (ub_ == kint64max) {
00340 output += StringPrintf(" >= %lld)", lb_);
00341 } else {
00342 output += StringPrintf(" in [%lld .. %lld])", lb_, ub_);
00343 }
00344 return output;
00345 }
00346 private:
00347 hash_map<int, int64> coefficients_;
00348 int64 lb_;
00349 int64 ub_;
00350 };
00351
00352 class OrConstraint : public ArithmeticConstraint {
00353 public:
00354 OrConstraint(ArithmeticConstraint* const left,
00355 ArithmeticConstraint* const right)
00356 : left_(left), right_(right) {}
00357
00358 virtual ~OrConstraint() {}
00359
00360 virtual bool Propagate(BoundsStore* const store) {
00361 return true;
00362 }
00363
00364 virtual void Replace(int to_replace, int var, int64 offset) {
00365 left_->Replace(to_replace, var, offset);
00366 right_->Replace(to_replace, var, offset);
00367 }
00368
00369 virtual bool Deduce(ArithmeticPropagator* const propagator) const {
00370 return false;
00371 }
00372
00373 virtual string DebugString() const {
00374 return StringPrintf("Or(%s, %s)",
00375 left_->DebugString().c_str(),
00376 right_->DebugString().c_str());
00377 }
00378 private:
00379 ArithmeticConstraint* const left_;
00380 ArithmeticConstraint* const right_;
00381 };
00382
00383
00384
00385 GlobalArithmeticConstraint::GlobalArithmeticConstraint(Solver* const solver)
00386 : Constraint(solver),
00387 propagator_(NULL) {
00388 propagator_.reset(new ArithmeticPropagator(
00389 solver,
00390 solver->MakeDelayedConstraintInitialPropagateCallback(this)));
00391 }
00392 GlobalArithmeticConstraint::~GlobalArithmeticConstraint() {
00393 STLDeleteElements(&constraints_);
00394 }
00395
00396 void GlobalArithmeticConstraint::Post() {
00397 const vector<IntVar*>& vars = propagator_->vars();
00398 for (int var_index = 0; var_index < vars.size(); ++var_index) {
00399 Demon* const demon =
00400 MakeConstraintDemon1(solver(),
00401 this,
00402 &GlobalArithmeticConstraint::Update,
00403 "Update",
00404 var_index);
00405 vars[var_index]->WhenRange(demon);
00406 }
00407 LOG(INFO) << "----- Before reduction -----";
00408 propagator_->PrintModel();
00409 LOG(INFO) << "----- After reduction -----";
00410 propagator_->ReduceProblem();
00411 propagator_->PrintModel();
00412 LOG(INFO) << "---------------------------";
00413 propagator_->Post();
00414 }
00415
00416 void GlobalArithmeticConstraint::InitialPropagate() {
00417 propagator_->InitialPropagate();
00418 }
00419
00420 void GlobalArithmeticConstraint::Update(int var_index) {
00421 propagator_->Update(var_index);
00422 }
00423
00424 ConstraintRef GlobalArithmeticConstraint::MakeScalProdGreaterOrEqualConstant(
00425 const vector<IntVar*> vars,
00426 const vector<int64> coefficients,
00427 int64 constant) {
00428 RowConstraint* const constraint = new RowConstraint(constant, kint64max);
00429 for (int index = 0; index < vars.size(); ++index) {
00430 constraint->AddTerm(VarIndex(vars[index]), coefficients[index]);
00431 }
00432 return Store(constraint);
00433 }
00434
00435 ConstraintRef GlobalArithmeticConstraint::MakeScalProdLessOrEqualConstant(
00436 const vector<IntVar*> vars,
00437 const vector<int64> coefficients,
00438 int64 constant) {
00439 RowConstraint* const constraint = new RowConstraint(kint64min, constant);
00440 for (int index = 0; index < vars.size(); ++index) {
00441 constraint->AddTerm(VarIndex(vars[index]), coefficients[index]);
00442 }
00443 return Store(constraint);
00444 }
00445
00446 ConstraintRef GlobalArithmeticConstraint::MakeScalProdEqualConstant(
00447 const vector<IntVar*> vars,
00448 const vector<int64> coefficients,
00449 int64 constant) {
00450 RowConstraint* const constraint = new RowConstraint(constant, constant);
00451 for (int index = 0; index < vars.size(); ++index) {
00452 constraint->AddTerm(VarIndex(vars[index]), coefficients[index]);
00453 }
00454 return Store(constraint);
00455 }
00456
00457 ConstraintRef GlobalArithmeticConstraint::MakeSumGreaterOrEqualConstant(
00458 const vector<IntVar*> vars,
00459 int64 constant) {
00460 RowConstraint* const constraint = new RowConstraint(constant, kint64max);
00461 for (int index = 0; index < vars.size(); ++index) {
00462 constraint->AddTerm(VarIndex(vars[index]), 1);
00463 }
00464 return Store(constraint);
00465 }
00466
00467 ConstraintRef GlobalArithmeticConstraint::MakeSumLessOrEqualConstant(
00468 const vector<IntVar*> vars, int64 constant) {
00469 RowConstraint* const constraint = new RowConstraint(kint64min, constant);
00470 for (int index = 0; index < vars.size(); ++index) {
00471 constraint->AddTerm(VarIndex(vars[index]), 1);
00472 }
00473 return Store(constraint);
00474 }
00475
00476 ConstraintRef GlobalArithmeticConstraint::MakeSumEqualConstant(
00477 const vector<IntVar*> vars, int64 constant) {
00478 RowConstraint* const constraint = new RowConstraint(constant, constant);
00479 for (int index = 0; index < vars.size(); ++index) {
00480 constraint->AddTerm(VarIndex(vars[index]), 1);
00481 }
00482 return Store(constraint);
00483 }
00484
00485 ConstraintRef GlobalArithmeticConstraint::MakeRowConstraint(
00486 int64 lb,
00487 const vector<IntVar*> vars,
00488 const vector<int64> coefficients,
00489 int64 ub) {
00490 RowConstraint* const constraint = new RowConstraint(lb, ub);
00491 for (int index = 0; index < vars.size(); ++index) {
00492 constraint->AddTerm(VarIndex(vars[index]), coefficients[index]);
00493 }
00494 return Store(constraint);
00495 }
00496
00497 ConstraintRef GlobalArithmeticConstraint::MakeRowConstraint(int64 lb,
00498 IntVar* const v1,
00499 int64 coeff1,
00500 int64 ub) {
00501 RowConstraint* const constraint = new RowConstraint(lb, ub);
00502 constraint->AddTerm(VarIndex(v1), coeff1);
00503 return Store(constraint);
00504 }
00505
00506 ConstraintRef GlobalArithmeticConstraint::MakeRowConstraint(int64 lb,
00507 IntVar* const v1,
00508 int64 coeff1,
00509 IntVar* const v2,
00510 int64 coeff2,
00511 int64 ub) {
00512 RowConstraint* const constraint = new RowConstraint(lb, ub);
00513 constraint->AddTerm(VarIndex(v1), coeff1);
00514 constraint->AddTerm(VarIndex(v2), coeff2);
00515 return Store(constraint);
00516 }
00517
00518 ConstraintRef GlobalArithmeticConstraint::MakeRowConstraint(int64 lb,
00519 IntVar* const v1,
00520 int64 coeff1,
00521 IntVar* const v2,
00522 int64 coeff2,
00523 IntVar* const v3,
00524 int64 coeff3,
00525 int64 ub) {
00526 RowConstraint* const constraint = new RowConstraint(lb, ub);
00527 constraint->AddTerm(VarIndex(v1), coeff1);
00528 constraint->AddTerm(VarIndex(v2), coeff2);
00529 constraint->AddTerm(VarIndex(v3), coeff3);
00530 return Store(constraint);
00531 }
00532
00533 ConstraintRef GlobalArithmeticConstraint::MakeRowConstraint(int64 lb,
00534 IntVar* const v1,
00535 int64 coeff1,
00536 IntVar* const v2,
00537 int64 coeff2,
00538 IntVar* const v3,
00539 int64 coeff3,
00540 IntVar* const v4,
00541 int64 coeff4,
00542 int64 ub) {
00543 RowConstraint* const constraint = new RowConstraint(lb, ub);
00544 constraint->AddTerm(VarIndex(v1), coeff1);
00545 constraint->AddTerm(VarIndex(v2), coeff2);
00546 constraint->AddTerm(VarIndex(v3), coeff3);
00547 constraint->AddTerm(VarIndex(v4), coeff4);
00548 return Store(constraint);
00549 }
00550
00551 ConstraintRef GlobalArithmeticConstraint::MakeOrConstraint(
00552 ConstraintRef left_ref,
00553 ConstraintRef right_ref) {
00554 OrConstraint* const constraint =
00555 new OrConstraint(constraints_[left_ref.index()],
00556 constraints_[right_ref.index()]);
00557 return Store(constraint);
00558 }
00559
00560 void GlobalArithmeticConstraint::Add(ConstraintRef ref) {
00561 propagator_->AddConstraint(constraints_[ref.index()]);
00562 }
00563
00564 int GlobalArithmeticConstraint::VarIndex(IntVar* const var) {
00565 hash_map<IntVar*, int>::const_iterator it = var_indices_.find(var);
00566 if (it == var_indices_.end()) {
00567 const int new_index = var_indices_.size();
00568 var_indices_.insert(make_pair(var, new_index));
00569 propagator_->AddVariable(var->Min(), var->Max());
00570 return new_index;
00571 } else {
00572 return it->second;
00573 }
00574 }
00575
00576 ConstraintRef GlobalArithmeticConstraint::Store(
00577 ArithmeticConstraint* const constraint) {
00578 const int constraint_index = constraints_.size();
00579 constraints_.push_back(constraint);
00580 return ConstraintRef(constraint_index);
00581 }
00582 }