#include <algorithm>
#include <iostream>
#include <unordered_map>
#include <vector>

#include <fmt/core.h>

using namespace fmt;

// Integer 2D coordinate. Used both as an object's position and as the
// key identifying a cell in the spatial hash table.
struct Point {
  int x;
  int y;

  // Two points are equal when both coordinates match.
  bool operator==(const Point& rhs) const {
    if (x != rhs.x) {
      return false;
    }
    return y == rhs.y;
  }
};

// An entity that can be stored in the SpatialHashTable. The table keeps
// raw pointers to Object instances and files each one under its position.
struct Object {
  // Coordinate the table uses as the lookup key for this object.
  Point position;
  // Add other object data as needed
};

// Hash functor that lets Point be used as an unordered_map key.
//
// The two coordinate hashes are mixed with an order-sensitive combine
// step (same shape as boost::hash_combine). A plain XOR would be symmetric,
// so (a, b) and (b, a) would always collide and every diagonal point
// (a, a) would hash to 0 — a poor fit for a grid keyed by coordinates.
struct PointHash {
  size_t operator()(const Point& p) const {
    const size_t hx = std::hash<int>()(p.x);
    const size_t hy = std::hash<int>()(p.y);
    // 0x9e3779b9 is the usual golden-ratio constant; the shifts spread
    // the bits of hx so the result depends on the order of x and y.
    return hx ^ (hy + 0x9e3779b9 + (hx << 6) + (hx >> 2));
  }
};

class SpatialHashTable {
  public:
    SpatialHashTable() {}

    void insert(Object* obj) {
      table[obj->position].push_back(obj);
    }

    void remove(Object* obj) {
      table.erase(obj->position);
    }

    std::vector<Object*> getNearbyObjects(Point position) {
      std::vector<Object*> result;
      Point cell = position;

      // Check the current cell and its 8 neighbors
      for (int x = cell.x - 1; x <= cell.x + 1; x++) {
        for (int y = cell.y - 1; y <= cell.y + 1; y++) {
          Point neighborCell = {x, y};
          auto it = table.find(neighborCell);
          if (it != table.end()) {
            result.insert(result.end(), it->second.begin(), it->second.end());
          }
        }
      }

      return result;
    }

  private:
    std::unordered_map<Point, std::vector<Object*>, PointHash> table;
};



int main() {
  SpatialHashTable hashTable;
  Object obj1 = {{5, 5}};
  Object obj2 = {{15, 15}};
  Object bomb = {{25, 25}};

  hashTable.insert(&obj1);
  hashTable.insert(&obj2);
  hashTable.insert(&bomb);

  std::vector<Object*> nearby = hashTable.getNearbyObjects({24, 24});

  for (Object* obj : nearby) {
    println("{},{}", obj->position.x, obj->position.y);
  }

  println("AFTER MOVE");

  // now attempt a move
  hashTable.remove(&bomb);
  bomb.position.x += 1;
  bomb.position.y += 1;
  hashTable.insert(&bomb);

  nearby = hashTable.getNearbyObjects({24, 24});

  for (Object* obj : nearby) {
    println("{},{}", obj->position.x, obj->position.y);
  }

  println("AFTER MOVE BACK");

  // now attempt a move
  hashTable.remove(&bomb);
  bomb.position.x -= 3;
  bomb.position.y -= 2;
  hashTable.insert(&bomb);

  nearby = hashTable.getNearbyObjects({24, 24});

  for (Object* obj : nearby) {
    println("{},{}", obj->position.x, obj->position.y);
  }

  return 0;
}