13#include "mlir/IR/Builders.h"
14#include "llvm/ADT/TypeSwitch.h"
33 DenseSet<PrimitiveType> &primsAtLoc = getLeaf(loc);
34 PrimitiveType prim = loc.getPrimitiveType().getValue();
35 if (primsAtLoc.contains(prim))
37 primsAtLoc.insert(prim);
45 DenseSet<PrimitiveType> primsAtLoc = getLeaf(loc);
46 return primsAtLoc.contains(loc.getPrimitiveType().getValue());
50 return placements[loc.getX()][loc.getY()][loc.getNum()];
53void PrimitiveDB::foreach (
54 function_ref<
void(PhysLocationAttr)> callback)
const {
55 for (
const auto &x : placements)
56 for (
const auto &y : x.second)
57 for (
const auto &n : y.second)
58 for (
auto p : n.second)
59 callback(PhysLocationAttr::get(ctxt, PrimitiveTypeAttr::get(ctxt, p),
60 x.first, y.first, n.first));
73 : ctxt(topMod->getContext()), topMod(topMod), seeded(false) {
77 : ctxt(topMod->getContext()), topMod(topMod), seeded(false) {
79 seed.
foreach ([
this](PhysLocationAttr loc) { (void)
getLeaf(loc); });
87 PhysLocationAttr loc, StringRef subPath,
89 StringAttr subPathAttr;
91 subPathAttr = StringAttr::get(inst->getContext(), subPath);
92 auto builder = OpBuilder(inst.getBody());
94 builder, srcLoc, loc, subPathAttr, FlatSymbolRefAttr());
101 LocationVectorAttr locs,
103 auto builder = OpBuilder(inst.getBody());
104 PDRegPhysLocationOp locOp =
105 PDRegPhysLocationOp::create(builder, srcLoc, locs, FlatSymbolRefAttr());
106 for (PhysLocationAttr loc : locs.getLocs())
115 PhysLocationAttr loc) {
120 return op->emitOpError(
"Could not apply placement. Invalid location: ")
122 if (leaf->
locOp !=
nullptr)
123 return op->emitOpError(
"Could not apply placement ")
124 << loc <<
". Position already occupied by "
125 << cast<DynamicInstanceOp>(leaf->
locOp->getParentOp()).getPath();
134 StringRef subPath, Location srcLoc) {
135 StringAttr subPathAttr;
136 if (!subPath.empty())
137 subPathAttr = StringAttr::get(inst->getContext(), subPath);
138 auto builder = OpBuilder::atBlockEnd(&inst.getBody().front());
139 PDPhysRegionOp regOp = PDPhysRegionOp::create(
140 builder, srcLoc, FlatSymbolRefAttr::get(physregion), subPathAttr,
141 FlatSymbolRefAttr());
149 size_t numFailed = 0;
150 inst->walk([&](Operation *op) {
151 LogicalResult added = TypeSwitch<Operation *, LogicalResult>(op)
155 .Case([&](PDRegPhysLocationOp op) {
156 ArrayRef<PhysLocationAttr> locs =
157 op.getLocs().getLocs();
158 for (
auto loc : locs)
163 .Case([&](PDPhysRegionOp op) {
167 .Default([](Operation *op) {
return failure(); });
194 PhysLocationAttr newLoc) {
195 PhysLocationAttr from = locOp.getLoc();
198 locOp.setLocAttr(newLoc);
206 for (PhysLocationAttr loc : locOp.getLocs().getLocs())
216 LocationVectorAttr newLocs) {
217 ArrayRef<PhysLocationAttr> fromLocs = locOp.getLocs().getLocs();
221 for (
auto [from, to] :
llvm::zip(fromLocs, newLocs.getLocs())) {
232 for (
auto [from, to] :
llvm::zip(fromLocs, newLocs.getLocs())) {
246 locOp.setLocsAttr(newLocs);
251 PhysLocationAttr loc) {
252 PlacementCell *leaf =
getLeaf(loc);
253 assert(leaf &&
"Could not find op at location specified by op");
254 assert(leaf->locOp == op);
259 PhysLocationAttr from,
260 PhysLocationAttr to) {
267 if (!oldLeaf || !newLeaf)
270 if (oldLeaf->
locOp ==
nullptr)
271 return op.emitError(
"cannot move from a location not occupied by "
272 "specified op. Currently unoccupied");
273 if (oldLeaf->
locOp != op)
274 return op.emitError(
"cannot move from a location not occupied by "
275 "specified op. Currently occupied by ")
279 "cannot move to new location since location is occupied by ")
280 << cast<DynamicInstanceOp>(newLeaf->
locOp->getParentOp()).getPath();
285 PhysLocationAttr from, PhysLocationAttr to) {
287 "Call `movePlacementCheck` first to ensure that move is legal.");
298 auto innerMap =
placements[loc.getX()][loc.getY()][loc.getNum()];
299 auto instF = innerMap.find(loc.getPrimitiveType().getValue());
300 if (instF == innerMap.end())
302 if (!instF->getSecond().locOp)
304 return instF->getSecond().locOp;
309 uint64_t nearestToY) {
311 PhysLocationAttr nearest = {};
313 [&nearest, nearestToY](PhysLocationAttr loc, Operation *locOp) {
321 std::abs((int64_t)nearestToY - (int64_t)nearest.getY());
322 int64_t replDist = std::abs((int64_t)nearestToY - (int64_t)loc.getY());
323 if (replDist < curDist)
326 std::make_tuple(columnNum, columnNum, -1, -1), prim);
331 PrimitiveType primType = loc.getPrimitiveType().getValue();
335 return &nums[loc.getNum()][primType];
336 if (!nums.count(loc.getNum()))
340 if (primitives.count(primType) == 0)
342 return &primitives[primType];
347 function_ref<
void(PhysLocationAttr, DynInstDataOpInterface)> callback,
348 std::tuple<int64_t, int64_t, int64_t, int64_t> bounds,
349 std::optional<PrimitiveType> primType, std::optional<WalkOrder> walkOrder) {
350 uint64_t xmin = std::get<0>(bounds) < 0 ? 0 : std::get<0>(bounds);
351 uint64_t xmax = std::get<1>(bounds) < 0 ? std::numeric_limits<uint64_t>::max()
352 : (uint64_t)std::get<1>(bounds);
353 uint64_t ymin = std::get<2>(bounds) < 0 ? 0 : std::get<2>(bounds);
354 uint64_t ymax = std::get<3>(bounds) < 0 ? std::numeric_limits<uint64_t>::max()
355 : (uint64_t)std::get<3>(bounds);
362 auto maybeSort = [](
auto &container,
auto direction) {
363 if (!direction.has_value())
368 llvm::sort(container, [direction](
auto colA,
auto colB) {
370 return colA.first < colB.first;
372 return colA.first > colB.first;
377 SmallVector<std::pair<size_t, DimYMap>> cols(
placements.begin(),
379 maybeSort(cols, llvm::transformOptional(walkOrder,
380 [](
auto wo) {
return wo.columns; }));
381 for (
const auto &colF : cols) {
382 size_t x = colF.first;
383 if (x < xmin || x > xmax)
388 SmallVector<std::pair<size_t, DimNumMap>> rows(yMap.begin(), yMap.end());
389 maybeSort(rows, llvm::transformOptional(walkOrder,
390 [](
auto wo) {
return wo.rows; }));
391 for (
const auto &rowF : rows) {
392 size_t y = rowF.first;
393 if (y < ymin || y > ymax)
398 for (
auto &numF : numMap) {
399 size_t num = numF.getFirst();
403 for (
auto &devF : devMap) {
404 PrimitiveType devtype = devF.getFirst();
405 if (primType && devtype != *primType)
407 PlacementCell &inst = devF.getSecond();
410 PhysLocationAttr loc = PhysLocationAttr::get(
411 ctxt, PrimitiveTypeAttr::get(
ctxt, devtype), x, y, num);
412 callback(loc, inst.locOp);
421 function_ref<
void(PDPhysRegionOp)> callback) {
assert(baseType &&"element must be base type")
PlacementDB(MlirModule top, PrimitiveDB *seed)
bool isValidLocation(MlirAttribute loc)
PrimitiveDB(MlirContext ctxt)
bool addPrimitive(MlirAttribute locAndPrim)
LogicalResult movePlacementCheck(DynInstDataOpInterface op, PhysLocationAttr from, PhysLocationAttr to)
Check to make sure the move is going to succeed.
void removePlacement(PDPhysLocationOp)
Remove the placement from the DB and IR. Erases the op.
DynInstDataOpInterface getInstanceAt(PhysLocationAttr)
Lookup the instance at a particular location.
size_t addPlacements(DynamicInstanceOp inst)
Load the placements from inst.
LogicalResult movePlacement(PDPhysLocationOp, PhysLocationAttr)
Move a placement location to a new location.
DenseMap< size_t, DimNumMap > DimYMap
PDPhysLocationOp place(DynamicInstanceOp inst, PhysLocationAttr, StringRef subpath, Location srcLoc)
Assign an instance to a primitive.
size_t addDesignPlacements()
Load the database from the IR.
PlacementCell * getLeaf(PhysLocationAttr)
Get the leaf node.
DenseMap< PrimitiveType, PlacementCell > DimDevType
PlacementDB(mlir::ModuleOp topMod)
Create a placement db containing all the placements in 'topMod'.
void walkRegionPlacements(function_ref< void(PDPhysRegionOp)>)
Walk the region placement information.
PhysLocationAttr getNearestFreeInColumn(PrimitiveType prim, uint64_t column, uint64_t nearestToY)
Find the nearest unoccupied primitive location to 'nearestToY' in 'column'.
PDPhysRegionOp placeIn(DynamicInstanceOp inst, DeclPhysicalRegionOp, StringRef subPath, Location srcLoc)
Assign an operation to a physical region. Return false on failure.
void walkPlacements(function_ref< void(PhysLocationAttr, DynInstDataOpInterface)>, std::tuple< int64_t, int64_t, int64_t, int64_t > bounds=std::make_tuple(-1, -1, -1, -1), std::optional< PrimitiveType > primType={}, std::optional< WalkOrder >={})
Walk the placement information in some sort of reasonable order.
DenseMap< size_t, DimDevType > DimNumMap
LogicalResult insertPlacement(DynInstDataOpInterface op, PhysLocationAttr)
RegionPlacements regionPlacements
A data structure to contain locations of the primitives on the device.
void foreach(function_ref< void(PhysLocationAttr)> callback) const
Iterate over all the primitive locations, executing 'callback' on each one.
DenseSet< PrimitiveType > DimPrimitiveType
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
DynInstDataOpInterface locOp