18 #include <unordered_map>
29 template <
class state>
37 template <
class state>
47 template <
class state>
61 template <
class state>
70 template <
class state,
class action,
class environment>
// State-sequence overload of GetPath: searches from `from` toward `to` and
// writes the resulting sequence of states into thePath.
// NOTE(review): the out-of-line definition is not visible here. It is presumably
// the template at original line ~490. That template resets the expansion counters,
// seeds g(from)=0 if unset, and then either makes a "trapped" move or expands the LSS.
// It then learns h-costs, marks redundant states dead and extracts the path. Confirm.
84 void GetPath(environment *env,
const state& from,
const state& to, std::vector<state> &thePath);
// Action-sequence overload of GetPath: not supported by this search.
// The body only executes assert(false), so it aborts in debug builds.
// When NDEBUG is defined the assert is compiled out, and the call returns
// without touching the output vector. Callers must use the state-vector overload.
85 void GetPath(environment *,
const state& ,
const state& , std::vector<action> & ) { assert(
false); };
93 for (
typename LearnedStateData::const_iterator it =
stateData.begin(); it !=
stateData.end(); it++)
98 std::cout << s->
theState <<
" is dead but has parents and/or children " << std::endl;
100 for (
unsigned int x = 0; x < s->
children->size(); x++)
113 std::cout << s->
theState <<
" lists " << s->
children->at(x) <<
" as child, but the child doesn't list the parent" << std::endl;
128 uint64_t hash =
m_pEnv->GetStateHash(where);
131 if (
verbose) std::cout <<
"Killing " << where << std::endl;
132 for (
unsigned int x = 0; x < theState.
parents->size(); x++)
137 for (
unsigned int x = 0; x < theState.
children->size(); x++)
147 theState.
dead =
true;
155 uint64_t hash =
m_pEnv->GetStateHash(where);
156 typename LearnedStateData::iterator it =
stateData.find(hash);
162 return (*it).second.dead;
173 std::cout <<
"ABORT! trying to add dead child " << child <<
" to parent " << parent << std::endl;
176 uint64_t hash =
m_pEnv->GetStateHash(child);
179 for (
unsigned int x = 0; x < theState.
parents->size(); x++)
181 if (theState.
parents->at(x) == parent)
184 theState.
parents->push_back(parent);
189 uint64_t hash =
m_pEnv->GetStateHash(child);
192 for (
unsigned int x = 0; x < theState.
parents->size(); x++)
194 if (theState.
parents->at(x) == parent)
197 std::cout << parent <<
" was parent of " << child << std::endl;
210 void AddChild(
const state &parent,
const state &child)
212 uint64_t hash =
m_pEnv->GetStateHash(parent);
215 for (
unsigned int x = 0; x < theState.
children->size(); x++)
217 if (theState.
children->at(x) == child)
220 theState.
children->push_back(child);
225 uint64_t hash =
m_pEnv->GetStateHash(parent);
228 for (
unsigned int x = 0; x < theState.
children->size(); x++)
230 if (theState.
children->at(x) == child)
232 if (
verbose) std::cout << child <<
" was child of " << parent << std::endl;
253 void SetGCost(environment *env,
const state &where,
double val)
255 uint64_t hash = env->GetStateHash(where);
259 theState.
parents =
new std::vector<state>();
261 theState.
children =
new std::vector<state>();
263 for (
unsigned int x = 0; x < theState.
parents->size(); x++)
268 for (
unsigned int x = 0; x < theState.
children->size(); x++)
274 if (
verbose) std::cout <<
"-->GCost of " << where <<
" setting to " << val << std::endl;
278 theState.
gCost = val;
280 theState.
dead =
false;
283 double FCost(
const state &where,
const state &goal)
289 double GCost(environment *env,
const state &where)
291 uint64_t hash = env->GetStateHash(where);
293 typename LearnedStateData::iterator it =
stateData.find(hash);
297 return (*it).second.gCost;
305 void SetHCost(environment *env,
const state &where,
const state &to,
double val)
307 double tmp = val-env->HCost(where, to);
308 if (tmp < 0) tmp = 0;
310 stateData[env->GetStateHash(where)].hCost = tmp;
313 double HCost(environment *env,
const state &from,
const state &to)
const
315 auto val =
stateData.find(env->GetStateHash(from));
318 return val->second.hCost+env->HCost(from, to);
320 return env->HCost(from, to);
330 double HCost(
const state &from,
const state &to)
const
335 uint64_t hash =
m_pEnv->GetStateHash(where);
338 if (where == goal)
return;
340 if (theState.
dead)
return;
341 if (
verbose) std::cout <<
"Trying to kill: " << where << std::endl;
342 for (
unsigned int x = 0; x < theState.
children->size(); x++)
349 printf(
"---But the state below is dead!\n");
350 std::cout <<
" Thwarted by " << theState.
children->at(x) << std::endl;
382 for (
int x = 0; x <
stateData[
m_pEnv->GetStateHash(where)].children->size(); x++)
421 std::vector<state> succ;
422 m_pEnv->GetSuccessors(where, succ);
423 for (
unsigned int x = 0; x < succ.size(); x++)
435 if (
verbose) std::cout << where <<
" has " << count <<
" successors" << std::endl;
441 std::vector<state> succ;
442 m_pEnv->GetSuccessors(child, succ);
443 for (
unsigned int x = 0; x < succ.size(); x++)
445 if (succ[x] == parent)
continue;
// Draws the learned per-state search data using the given environment (const member).
// NOTE(review): the definition is presumably the template at original line ~949.
// That template draws parent/child links, labels states with their g/h values,
// and colors dead states differently. Confirm.
463 void OpenGLDraw(
const environment *env)
const;
// Map from a 64-bit state hash to a boolean flag (hashed with Hash64).
// NOTE(review): presumably used as a "visited/updated" set, e.g. the
// `c[env->GetStateHash(...)] = true` pattern during h-cost learning. Confirm.
466 typedef std::unordered_map<uint64_t, bool, Hash64 >
ClosedList;
// Private search helpers. NOTE(review): the mapping of each declaration to an
// out-of-line definition below is inferred from call sites and verbose
// messages in the fragments, not from visible signatures. Confirm.

// Fills thePath with the move chosen after learning. The definition appears to do two things:
// - Pick the open-list state with the lowest g + fWeight*h, skipping dead states and `from`.
//   If no such state exists, it falls back to MakeTrappedMove.
// - Follow parentID links back to the root, pushing each state into thePath.
467 void ExtractBestPath(environment *env,
const state &from,
const state &to, std::vector<state> &thePath);
// Fallback move: steps to a successor of `from` with a smaller g-cost than
// g(from), pushing that successor and then `from` into thePath.
// Presumably asserts if no such successor exists.
468 void MakeTrappedMove(environment *env,
const state &from, std::vector<state> &thePath);
// Expands the local search space (LSS) from `from`, up to nodeExpansionLimit
// expansions, propagating g-costs as it goes. If `to` is reached, the path
// to it is written into thePath.
470 void ExpandLSS(
const state &from,
const state &to, std::vector<state> &thePath);
// Updates learned h-cost values from the LSS border toward `to`.
// Presumably also accumulates the total change in fAmountLearned.
472 void DoHCostLearning(environment *env,
const state &from,
const state &to);
// Propagates g-cost improvements from `next` to its neighbors, maintaining
// parent/child links. When alsoExpand is true, neighbors are also added to
// or updated on the open list.
473 void PropagateGCosts(
const state &next,
const state &to,
bool alsoExpand);
490 template <
class state,
class action,
class environment>
497 nodesExpanded = nodesTouched = 0;
502 assert(gCostQueue.empty());
506 if (GCost(from) == DBL_MAX)
507 SetGCost(env, from, 0);
511 MakeTrappedMove(env, from, thePath);
515 ExpandLSS(from, to, thePath);
521 if (thePath.size() != 0)
524 DoHCostLearning(env, from, to);
528 MarkDeadRedundant(env, to);
530 ExtractBestPath(env, from, to, thePath);
533 if (
verbose) std::cout <<
"FLRTA* heading towards " << thePath[0] <<
" with h-value " << HCost(env, thePath[0], to) << std::endl;
538 template <
class state,
class action,
class environment>
542 unsigned int cnt = 0;
543 for (; cnt < aoc.OpenSize(); cnt++)
545 if (!IsDead(aoc.Lookat(aoc.GetOpenItem(cnt)).data) &&
546 !(aoc.Lookat(aoc.GetOpenItem(cnt)).data == from))
548 best = aoc.Lookat(aoc.GetOpenItem(cnt)).data;
553 if (cnt == aoc.OpenSize())
555 MakeTrappedMove(env, from, thePath);
560 for (; cnt < aoc.OpenSize(); cnt++)
563 if (IsDead(data.
data))
565 double tmp = fWeight;
568 if (followLocalGCost)
570 if (
fgreater(bestG+fWeight*HCost(best, to),
571 data.
g+fWeight*HCost(data.
data, to)) &&
572 !(data.
data == from))
577 else if (
fequal(bestG+fWeight*HCost(best, to), data.
g+fWeight*HCost(data.
data, to)) &&
579 !(data.
data == from))
586 if (
fgreater(GCost(best)+fWeight*HCost(best, to),
587 GCost(data.
data)+fWeight*HCost(data.
data, to)) &&
588 !(data.
data == from))
592 else if (
fequal(GCost(best)+fWeight*HCost(best, to), GCost(data.
data)+fWeight*HCost(data.
data, to)) &&
594 !(data.
data == from))
602 if (
verbose) std::cout <<
"Moving towards " << best <<
" cost " << GCost(best) << std::endl;
604 aoc.Lookup(env->GetStateHash(best),
node);
606 thePath.push_back(aoc.Lookup(
node).data);
608 }
while (aoc.Lookup(
node).parentID !=
node);
609 thePath.push_back(aoc.Lookup(
node).data);
612 template <
class state,
class action,
class environment>
615 std::vector<state> succ;
616 env->GetSuccessors(from, succ);
620 for (
unsigned int x = 0; x < succ.size(); x++)
622 if ((back == -1) && (GCost(succ[x]) < GCost(from)))
627 else if ((back != -1) && (GCost(succ[x]) < GCost(succ[back])))
635 std::cout <<
"No successors of " << from <<
" with smaller g-cost than " << GCost(from) << std::endl;
636 for (
unsigned int x = 0; x < succ.size(); x++)
637 std::cout << succ[x] <<
" has cost " << GCost(succ[x]) << std::endl;
638 assert(!
"Invalid environment; apparently cannot reach all predecessors of a state");
641 thePath.push_back(succ[back]);
642 thePath.push_back(from);
647 template <
class state,
class action,
class environment>
650 if (
verbose) std::cout <<
"Expanding LSS" << std::endl;
653 aoc.AddOpenNode(from, m_pEnv->GetStateHash(from), 0.0, 0.0);
654 std::vector<state> tempDead;
656 for (
int x = 0; x < nodeExpansionLimit; x++)
658 if (aoc.OpenSize() == 0)
661 uint64_t next = aoc.Close();
662 if (
verbose) std::cout <<
"Next in LSS: " << aoc.Lookat(next).data << std::endl;
665 if (IsDead(aoc.Lookat(next).data))
670 if (
verbose) std::cout << aoc.Lookat(next).data <<
" is dead; left on closed (LSS)" << std::endl;
676 state val = aoc.Lookat(next).data;
677 PropagateGCosts(val, to,
true);
681 if (aoc.Lookat(next).data == to)
683 uint64_t
node = next;
686 thePath.push_back(aoc.Lookup(
node).data);
688 }
while (aoc.Lookup(
node).parentID !=
node);
689 thePath.push_back(aoc.Lookup(
node).data);
695 for (
unsigned int x = 0; x < aoc.OpenSize(); x++)
697 state s = aoc.Lookat(aoc.GetOpenItem(x)).data;
699 PropagateGCosts(s, to,
false);
705 template <
class state,
class action,
class environment>
708 if (
verbose) std::cout <<
"=Propagating from: " << next << (alsoExpand?
" and expanding":
" not expanding") << std::endl;
712 std::vector<state> *neighbors = vc.
getItem();;
713 m_pEnv->GetSuccessors(next, *neighbors);
717 dataLocation pLoc = aoc.Lookup(m_pEnv->GetStateHash(next), parentKey);
719 if (
verbose) std::cout << GCost(next) <<
" gcost in " <<
721 for (
unsigned int x = 0; x < neighbors->size(); x++)
724 double edgeCost = m_pEnv->GCost(next, neighbors->at(x));
727 dataLocation cLoc = aoc.Lookup(m_pEnv->GetStateHash(neighbors->at(x)), childKey);
735 if (
verbose) std::cout <<
"Adding " << neighbors->at(x) <<
" to open with f:" <<
736 aoc.Lookat(parentKey).g+edgeCost + HCost(neighbors->at(x), to) << std::endl;
737 aoc.AddOpenNode(neighbors->at(x), m_pEnv->GetStateHash(neighbors->at(x)),
738 aoc.Lookat(parentKey).g+edgeCost,
740 HCost(neighbors->at(x), to), parentKey);
745 if (
fless(aoc.Lookup(parentKey).g+edgeCost, aoc.Lookup(childKey).g))
747 if (
verbose) std::cout <<
"Updating " << neighbors->at(x) <<
" on open" << std::endl;
748 aoc.Lookup(childKey).parentID = parentKey;
749 aoc.Lookup(childKey).g = aoc.Lookup(parentKey).g+edgeCost;
750 aoc.KeyChanged(childKey);
755 if (
fless(aoc.Lookup(parentKey).g+edgeCost, aoc.Lookup(childKey).g))
757 if (
verbose) std::cout <<
"Reopening " << neighbors->at(x) << std::endl;
758 aoc.Lookup(childKey).parentID = parentKey;
759 aoc.Lookup(childKey).g = aoc.Lookup(parentKey).g+edgeCost;
760 aoc.Reopen(childKey);
767 assert(GCost(next) != DBL_MAX);
769 if (!IsDead(neighbors->at(x)) &&
fequal(GCost(next)+edgeCost, GCost(neighbors->at(x))))
771 AddParent(next, neighbors->at(x));
772 AddChild(next, neighbors->at(x));
774 if (
fless(GCost(next)+edgeCost, GCost(neighbors->at(x))))
776 if (
verbose) std::cout <<
"Updating " << neighbors->at(x) <<
" from " << GCost(neighbors->at(x)) <<
777 " to " << GCost(next) <<
"(" << next <<
") + " << edgeCost <<
" = " << GCost(next)+edgeCost << std::endl;
778 if (IsDead(neighbors->at(x)) && (cLoc ==
kClosedList))
783 SetGCost(m_pEnv, neighbors->at(x), GCost(next)+edgeCost);
785 AddParent(next, neighbors->at(x));
786 AddChild(next, neighbors->at(x));
787 PropagateGCosts(neighbors->at(x), to,
true);
790 SetGCost(m_pEnv, neighbors->at(x), GCost(next)+edgeCost);
792 AddParent(next, neighbors->at(x));
793 AddChild(next, neighbors->at(x));
796 PropagateGCosts(neighbors->at(x), to,
false);
802 if (
fless(edgeCost+GCost(neighbors->at(x)), GCost(next)) && !IsDead(neighbors->at(x)))
806 if (
verbose) std::cout <<
"[Recursing to] Update " << next <<
" from " << GCost(next) <<
807 " to " << GCost(neighbors->at(x)) <<
"(" << neighbors->at(x) <<
") + " << edgeCost <<
" = " << GCost(neighbors->at(x))+edgeCost << std::endl;
808 SetGCost(m_pEnv, next, GCost(neighbors->at(x))+edgeCost);
810 if (!IsDead(neighbors->at(x)))
812 AddParent(neighbors->at(x), next);
813 AddChild(neighbors->at(x), next);
820 vc.returnItem(neighbors);
821 if (
verbose) std::cout <<
"=Done Propagating from: " << next << std::endl;
824 template <
class state,
class action,
class environment>
832 for (
unsigned int x = 0; x <
openSize; x++)
838 if (
verbose) std::cout <<std::endl<<
">>>Preparing state: " << data.
data <<
" g: " << GCost(data.
data) << std::endl;
843 state first = q.top().theState;
848 state s = q.top().theState;
852 if (
fgreater(GCost(s)+HCost(s, goal), GCost(goal)))
854 if (
verbose) std::cout<<
"Marking " << GCost(s) <<
":" << HCost(s, goal) <<
" " << s <<
" as dead -- too far from goal: " << GCost(goal) << goal << std::endl;
861 template <
class state,
class action,
class environment>
868 std::vector<state> toKill;
870 std::vector<state> succ;
872 unsigned int openSize = aoc.OpenSize();
873 for (
unsigned int x = 0; x <
openSize; x++)
876 if (!IsDead(data.
data))
879 if (
verbose) std::cout <<
"Preparing border state: " << data.
data <<
" h: " << data.
h << std::endl;
887 state first = q.top().theState;
891 state s = q.top().theState;
892 if (
verbose) std::cout <<
"Propagating from " << s <<
" h: " << q.top().value <<
"/" << HCost(s, to) << std::endl;
895 env->GetSuccessors(s, succ);
898 double hCost = HCost(s, to);
900 for (
unsigned int x = 0; x < succ.size(); x++)
909 dataLocation pLoc = aoc.Lookup(env->GetStateHash(succ[x]), succKey);
915 double edgeCost = env->GCost(s, succ[x]);
916 succHCost = HCost(env, succ[x], to);
917 if (c[env->GetStateHash(succ[x])])
919 if (
verbose) std::cout << succ[x] <<
" updated before " << std::endl;
920 if (
fless(hCost + edgeCost, succHCost))
922 fAmountLearned -= succHCost-hCost-edgeCost;
923 if (
verbose) std::cout <<
"lowering cost to " << hCost + edgeCost;
925 SetHCost(env, succ[x], to, hCost + edgeCost);
933 if (
verbose) std::cout << succ[x] <<
" NOT updated before ";
936 if (
verbose) std::cout <<
"setting cost to " << hCost + edgeCost <<
" over " << succHCost << std::endl;
937 fAmountLearned += (edgeCost + hCost) - succHCost;
939 SetHCost(env, succ[x], to, hCost + edgeCost);
941 c[env->GetStateHash(succ[x])] =
true;
949 template <
class state,
class action,
class environment>
957 for (
typename LearnedStateData::const_iterator it =
stateData.begin(); it !=
stateData.end(); it++)
959 double thisState = (*it).second.hCost;
960 if (learned < thisState)
963 for (
typename LearnedStateData::const_iterator it =
stateData.begin(); it !=
stateData.end(); it++)
969 for (
size_t x = 0; x < (*it).second.children->size(); x++)
971 m_pEnv->GLDrawLine((*it).second.theState, (*it).second.children->at(x));
973 for (
size_t x = 0; x < (*it).second.parents->size(); x++)
975 m_pEnv->GLDrawLine((*it).second.theState, (*it).second.parents->at(x));
980 if ((*it).second.dead)
981 sprintf(str,
" %1.1f", (*it).second.gCost);
983 sprintf(str,
"%1.1f %1.1f", (*it).second.gCost, (*it).second.hCost+m_pEnv->HCost((*it).second.theState, theEnd));
984 e->SetColor(0.9, 0.9, 0.9, 1);
985 e->GLLabelState((*it).second.theState, str);
988 if ((*it).second.dead)
990 e->SetColor(0.0, 0.0+((
loc==
kOpenList)?0.5:0.0), 0.0, 1);
991 e->OpenGLDraw((*it).second.theState);
995 double r = (*it).second.hCost;
998 e->SetColor(0.5+0.5*r/learned, ((
loc==
kOpenList)?0.5:0.0), 0, 0.1+0.8*r/learned);
999 e->OpenGLDraw((*it).second.theState);
1003 e->SetColor(0.0, 0.5, 0.0, 1);
1004 e->OpenGLDraw((*it).second.theState);