13#include "mlir/IR/Builders.h"
14#include "llvm/ADT/TypeSwitch.h"
33 DenseSet<PrimitiveType> &primsAtLoc = getLeaf(loc);
34 PrimitiveType prim = loc.getPrimitiveType().getValue();
35 if (primsAtLoc.contains(prim))
37 primsAtLoc.insert(prim);
45 DenseSet<PrimitiveType> primsAtLoc = getLeaf(loc);
46 return primsAtLoc.contains(loc.getPrimitiveType().getValue());
50 return placements[loc.getX()][loc.getY()][loc.getNum()];
53void PrimitiveDB::foreach (
54 function_ref<
void(PhysLocationAttr)> callback)
const {
55 for (
const auto &x : placements)
56 for (
const auto &y : x.second)
57 for (
const auto &n : y.second)
58 for (
auto p : n.second)
59 callback(PhysLocationAttr::get(ctxt, PrimitiveTypeAttr::get(ctxt, p),
60 x.first, y.first, n.first));
73 : ctxt(topMod->getContext()), topMod(topMod), seeded(false) {
77 : ctxt(topMod->getContext()), topMod(topMod), seeded(false) {
79 seed.
foreach ([
this](PhysLocationAttr loc) { (void)
getLeaf(loc); });
87 PhysLocationAttr loc, StringRef subPath,
89 StringAttr subPathAttr;
91 subPathAttr = StringAttr::get(inst->getContext(), subPath);
93 OpBuilder(inst.getBody())
102 LocationVectorAttr locs,
104 PDRegPhysLocationOp locOp =
105 OpBuilder(inst.getBody())
106 .create<PDRegPhysLocationOp>(srcLoc, locs, FlatSymbolRefAttr());
107 for (PhysLocationAttr loc : locs.getLocs())
116 PhysLocationAttr loc) {
121 return op->emitOpError(
"Could not apply placement. Invalid location: ")
123 if (leaf->
locOp !=
nullptr)
124 return op->emitOpError(
"Could not apply placement ")
125 << loc <<
". Position already occupied by "
126 << cast<DynamicInstanceOp>(leaf->
locOp->getParentOp()).getPath();
135 StringRef subPath, Location srcLoc) {
136 StringAttr subPathAttr;
137 if (!subPath.empty())
138 subPathAttr = StringAttr::get(inst->getContext(), subPath);
139 PDPhysRegionOp regOp =
140 OpBuilder::atBlockEnd(&inst.getBody().front())
141 .create<PDPhysRegionOp>(srcLoc, FlatSymbolRefAttr::get(physregion),
142 subPathAttr, FlatSymbolRefAttr());
150 size_t numFailed = 0;
151 inst->walk([&](Operation *op) {
152 LogicalResult added = TypeSwitch<Operation *, LogicalResult>(op)
156 .Case([&](PDRegPhysLocationOp op) {
157 ArrayRef<PhysLocationAttr> locs =
158 op.getLocs().getLocs();
159 for (
auto loc : locs)
164 .Case([&](PDPhysRegionOp op) {
168 .Default([](Operation *op) {
return failure(); });
195 PhysLocationAttr newLoc) {
196 PhysLocationAttr from = locOp.getLoc();
199 locOp.setLocAttr(newLoc);
207 for (PhysLocationAttr loc : locOp.getLocs().getLocs())
217 LocationVectorAttr newLocs) {
218 ArrayRef<PhysLocationAttr> fromLocs = locOp.getLocs().getLocs();
222 for (
auto [from, to] :
llvm::zip(fromLocs, newLocs.getLocs())) {
233 for (
auto [from, to] :
llvm::zip(fromLocs, newLocs.getLocs())) {
247 locOp.setLocsAttr(newLocs);
252 PhysLocationAttr loc) {
253 PlacementCell *leaf =
getLeaf(loc);
254 assert(leaf &&
"Could not find op at location specified by op");
255 assert(leaf->locOp == op);
260 PhysLocationAttr from,
261 PhysLocationAttr to) {
268 if (!oldLeaf || !newLeaf)
271 if (oldLeaf->
locOp ==
nullptr)
272 return op.emitError(
"cannot move from a location not occupied by "
273 "specified op. Currently unoccupied");
274 if (oldLeaf->
locOp != op)
275 return op.emitError(
"cannot move from a location not occupied by "
276 "specified op. Currently occupied by ")
280 "cannot move to new location since location is occupied by ")
281 << cast<DynamicInstanceOp>(newLeaf->
locOp->getParentOp()).getPath();
286 PhysLocationAttr from, PhysLocationAttr to) {
288 "Call `movePlacementCheck` first to ensure that move is legal.");
299 auto innerMap =
placements[loc.getX()][loc.getY()][loc.getNum()];
300 auto instF = innerMap.find(loc.getPrimitiveType().getValue());
301 if (instF == innerMap.end())
303 if (!instF->getSecond().locOp)
305 return instF->getSecond().locOp;
310 uint64_t nearestToY) {
312 PhysLocationAttr nearest = {};
314 [&nearest, nearestToY](PhysLocationAttr loc, Operation *locOp) {
322 std::abs((int64_t)nearestToY - (int64_t)nearest.getY());
323 int64_t replDist = std::abs((int64_t)nearestToY - (int64_t)loc.getY());
324 if (replDist < curDist)
327 std::make_tuple(columnNum, columnNum, -1, -1), prim);
332 PrimitiveType primType = loc.getPrimitiveType().getValue();
336 return &nums[loc.getNum()][primType];
337 if (!nums.count(loc.getNum()))
341 if (primitives.count(primType) == 0)
343 return &primitives[primType];
348 function_ref<
void(PhysLocationAttr, DynInstDataOpInterface)> callback,
349 std::tuple<int64_t, int64_t, int64_t, int64_t> bounds,
350 std::optional<PrimitiveType> primType, std::optional<WalkOrder> walkOrder) {
351 uint64_t xmin = std::get<0>(bounds) < 0 ? 0 : std::get<0>(bounds);
352 uint64_t xmax = std::get<1>(bounds) < 0 ? std::numeric_limits<uint64_t>::max()
353 : (uint64_t)std::get<1>(bounds);
354 uint64_t ymin = std::get<2>(bounds) < 0 ? 0 : std::get<2>(bounds);
355 uint64_t ymax = std::get<3>(bounds) < 0 ? std::numeric_limits<uint64_t>::max()
356 : (uint64_t)std::get<3>(bounds);
363 auto maybeSort = [](
auto &container,
auto direction) {
364 if (!direction.has_value())
369 llvm::sort(container, [direction](
auto colA,
auto colB) {
371 return colA.first < colB.first;
373 return colA.first > colB.first;
378 SmallVector<std::pair<size_t, DimYMap>> cols(
placements.begin(),
380 maybeSort(cols, llvm::transformOptional(walkOrder,
381 [](
auto wo) {
return wo.columns; }));
382 for (
const auto &colF : cols) {
383 size_t x = colF.first;
384 if (x < xmin || x > xmax)
389 SmallVector<std::pair<size_t, DimNumMap>> rows(yMap.begin(), yMap.end());
390 maybeSort(rows, llvm::transformOptional(walkOrder,
391 [](
auto wo) {
return wo.rows; }));
392 for (
const auto &rowF : rows) {
393 size_t y = rowF.first;
394 if (y < ymin || y > ymax)
399 for (
auto &numF : numMap) {
400 size_t num = numF.getFirst();
404 for (
auto &devF : devMap) {
405 PrimitiveType devtype = devF.getFirst();
406 if (primType && devtype != *primType)
408 PlacementCell &inst = devF.getSecond();
411 PhysLocationAttr loc = PhysLocationAttr::get(
412 ctxt, PrimitiveTypeAttr::get(
ctxt, devtype), x, y, num);
413 callback(loc, inst.locOp);
422 function_ref<
void(PDPhysRegionOp)> callback) {
assert(baseType &&"element must be base type")
PlacementDB(MlirModule top, PrimitiveDB *seed)
bool isValidLocation(MlirAttribute loc)
PrimitiveDB(MlirContext ctxt)
bool addPrimitive(MlirAttribute locAndPrim)
LogicalResult movePlacementCheck(DynInstDataOpInterface op, PhysLocationAttr from, PhysLocationAttr to)
Check to make sure the move is going to succeed.
void removePlacement(PDPhysLocationOp)
Remove the placement from the DB and IR. Erases the op.
DynInstDataOpInterface getInstanceAt(PhysLocationAttr)
Lookup the instance at a particular location.
size_t addPlacements(DynamicInstanceOp inst)
Load the placements from inst.
LogicalResult movePlacement(PDPhysLocationOp, PhysLocationAttr)
Move a placement location to a new location.
DenseMap< size_t, DimNumMap > DimYMap
PDPhysLocationOp place(DynamicInstanceOp inst, PhysLocationAttr, StringRef subpath, Location srcLoc)
Assign an instance to a primitive.
size_t addDesignPlacements()
Load the database from the IR.
PlacementCell * getLeaf(PhysLocationAttr)
Get the leaf node.
DenseMap< PrimitiveType, PlacementCell > DimDevType
PlacementDB(mlir::ModuleOp topMod)
Create a placement db containing all the placements in 'topMod'.
void walkRegionPlacements(function_ref< void(PDPhysRegionOp)>)
Walk the region placement information.
PhysLocationAttr getNearestFreeInColumn(PrimitiveType prim, uint64_t column, uint64_t nearestToY)
Find the nearest unoccupied primitive location to 'nearestToY' in 'column'.
PDPhysRegionOp placeIn(DynamicInstanceOp inst, DeclPhysicalRegionOp, StringRef subPath, Location srcLoc)
Assign an operation to a physical region. Return false on failure.
void walkPlacements(function_ref< void(PhysLocationAttr, DynInstDataOpInterface)>, std::tuple< int64_t, int64_t, int64_t, int64_t > bounds=std::make_tuple(-1, -1, -1, -1), std::optional< PrimitiveType > primType={}, std::optional< WalkOrder >={})
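Walk the placement information column by column and, within each column, row by row, restricted to 'bounds' (a negative bound means unbounded) and, if given, to 'primType'. Columns and rows are optionally sorted according to the WalkOrder.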
DenseMap< size_t, DimDevType > DimNumMap
LogicalResult insertPlacement(DynInstDataOpInterface op, PhysLocationAttr)
RegionPlacements regionPlacements
A data structure to contain locations of the primitives on the device.
void foreach(function_ref< void(PhysLocationAttr)> callback) const
Iterate over all the primitive locations, executing 'callback' on each one.
DenseSet< PrimitiveType > DimPrimitiveType
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
DynInstDataOpInterface locOp