diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td @@ -68,19 +68,6 @@ return op.getAffineMapAttr().getValue(); }] >, - InterfaceMethod< - /*desc=*/"Returns the AffineMapAttr associated with 'memref'.", - /*retTy=*/"NamedAttribute", - /*methodName=*/"getAffineMapAttrForMemRef", - /*args=*/(ins "Value":$memref), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ - ConcreteOp op = cast(this->getOperation()); - assert(memref == getMemRef()); - return {Identifier::get(op.getMapAttrName(), op.getContext()), - op.getAffineMapAttr()}; - }] - >, InterfaceMethod< /*desc=*/"Returns the value read by this operation.", /*retTy=*/"Value", @@ -148,27 +135,40 @@ }] >, InterfaceMethod< - /*desc=*/"Returns the AffineMapAttr associated with 'memref'.", - /*retTy=*/"NamedAttribute", - /*methodName=*/"getAffineMapAttrForMemRef", - /*args=*/(ins "Value":$memref), + /*desc=*/"Returns the value to store.", + /*retTy=*/"Value", + /*methodName=*/"getValueToStore", + /*args=*/(ins), /*methodBody=*/[{}], /*defaultImplementation=*/[{ ConcreteOp op = cast(this->getOperation()); - assert(memref == getMemRef()); - return {Identifier::get(op.getMapAttrName(), op.getContext()), - op.getAffineMapAttr()}; + return op.getOperand(op.getStoredValOperandIndex()); }] >, + ]; +} + +def AffineMapAccessInterface : OpInterface<"AffineMapAccessInterface"> { + let description = [{ + Interface to query the AffineMap used to dereference and access a given + memref. Implementers of this interface must operate on at least one + memref operand. The memref argument given to this interface must match + one of those memref operands. 
+ }]; + + let methods = [ InterfaceMethod< - /*desc=*/"Returns the value to store.", - /*retTy=*/"Value", - /*methodName=*/"getValueToStore", - /*args=*/(ins), + /*desc=*/"Returns the AffineMapAttr associated with 'memref'.", + /*retTy=*/"NamedAttribute", + /*methodName=*/"getAffineMapAttrForMemRef", + /*args=*/(ins "Value":$memref), /*methodBody=*/[{}], /*defaultImplementation=*/[{ ConcreteOp op = cast(this->getOperation()); - return op.getOperand(op.getStoredValOperandIndex()); + assert(memref == op.getMemRef() && + "Expected memref argument to match memref operand"); + return {Identifier::get(op.getMapAttrName(), op.getContext()), + op.getAffineMapAttr()}; }] >, ]; diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -82,7 +82,8 @@ // TODO: Consider replacing src/dst memref indices with view memrefs. class AffineDmaStartOp : public Op { + OpTrait::VariadicOperands, OpTrait::ZeroResult, + AffineMapAccessInterface::Trait> { public: using Op::Op; @@ -191,6 +192,7 @@ getTagMap().getNumInputs()); } + /// Implements the AffineMapAccessInterface. /// Returns the AffineMapAttr associated with 'memref'. NamedAttribute getAffineMapAttrForMemRef(Value memref) { if (memref == getSrcMemRef()) @@ -271,7 +273,8 @@ // class AffineDmaWaitOp : public Op { + OpTrait::VariadicOperands, OpTrait::ZeroResult, + AffineMapAccessInterface::Trait> { public: using Op::Op; @@ -303,6 +306,7 @@ return getTagMemRef().getType().cast().getRank(); } + /// Implements the AffineMapAccessInterface. /// Returns the AffineMapAttr associated with 'memref'. 
NamedAttribute getAffineMapAttrForMemRef(Value memref) { assert(memref == getTagMemRef()); diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -458,6 +458,7 @@ class AffineLoadOpBase traits = []> : Affine_Op, + DeclareOpInterfaceMethods, MemRefsNormalizable])> { let arguments = (ins Arg:$memref, @@ -698,7 +699,8 @@ let hasFolder = 1; } -def AffinePrefetchOp : Affine_Op<"prefetch"> { +def AffinePrefetchOp : Affine_Op<"prefetch", + [DeclareOpInterfaceMethods]> { let summary = "affine prefetch operation"; let description = [{ The "affine.prefetch" op prefetches data from a memref location described @@ -752,9 +754,11 @@ return (*this)->getAttr(getMapAttrName()).cast(); } + /// Implements the AffineMapAccessInterface. /// Returns the AffineMapAttr associated with 'memref'. NamedAttribute getAffineMapAttrForMemRef(Value mref) { - assert(mref == memref()); + assert(mref == memref() && + "Expected mref argument to match memref operand"); return {Identifier::get(getMapAttrName(), getContext()), getAffineMapAttr()}; } @@ -777,6 +781,7 @@ class AffineStoreOpBase traits = []> : Affine_Op, + DeclareOpInterfaceMethods, MemRefsNormalizable])> { code extraClassDeclarationBase = [{ /// Returns the operand index of the value to be stored. @@ -792,13 +797,6 @@ return (*this)->getAttr(getMapAttrName()).cast(); } - /// Returns the AffineMapAttr associated with 'memref'. 
- NamedAttribute getAffineMapAttrForMemRef(Value memref) { - assert(memref == getMemRef()); - return {Identifier::get(getMapAttrName(), getContext()), - getAffineMapAttr()}; - } - static StringRef getMapAttrName() { return "map"; } }]; } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -61,11 +61,6 @@ SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist); -static bool isMemRefDereferencingOp(Operation &op) { - // TODO: Support DMA Ops. - return isa(op); -} - // Returns true if the individual op is loop invariant. bool isOpLoopInvariant(Operation &op, Value indVar, SmallPtrSetImpl &definedOps, @@ -89,7 +84,7 @@ // which are themselves not being hoisted. definedOps.insert(&op); - if (isMemRefDereferencingOp(op)) { + if (isa(op)) { Value memref = isa(op) ? cast(op).getMemRef() : cast(op).getMemRef(); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -68,12 +68,6 @@ maximalFusion); } -// TODO: Replace when this is modeled through side-effects/op traits -static bool isMemRefDereferencingOp(Operation &op) { - return isa(op); -} - namespace { // LoopNestStateCollector walks loop nests and collects load and store @@ -264,7 +258,7 @@ return true; // Return true if any use of 'memref' escapes the function. for (auto *user : memref.getUsers()) - if (!isMemRefDereferencingOp(*user)) + if (!isa(*user)) return true; } return false; @@ -703,7 +697,7 @@ // Check if 'memref' escapes through a non-affine op (e.g., std load/store, // call op, etc.). 
for (Operation *user : memref.getUsers()) - if (!isMemRefDereferencingOp(*user)) + if (!isa(*user)) escapingMemRefs.insert(memref); } } @@ -979,7 +973,7 @@ // Interrupt the walk if found. auto walkResult = op->walk([&](Operation *user) { // Skip affine ops. - if (isMemRefDereferencingOp(*user)) + if (isa(*user)) return WalkResult::advance(); // Find a non-affine op that uses the memref. if (llvm::is_contained(users, user)) diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -25,22 +25,6 @@ #include "llvm/ADT/TypeSwitch.h" using namespace mlir; -/// Return true if this operation dereferences one or more memref's. -// Temporary utility: will be replaced when this is modeled through -// side-effects/op traits. TODO -static bool isMemRefDereferencingOp(Operation &op) { - return isa(op); -} - -/// Return the AffineMapAttr associated with memory 'op' on 'memref'. -static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) { - return TypeSwitch(op) - .Case( - [=](auto op) { return op.getAffineMapAttrForMemRef(memref); }); -} - // Perform the replacement in `op`. LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, Operation *op, @@ -88,17 +72,20 @@ OpBuilder builder(op); // The following checks if op is dereferencing memref and performs the access // index rewrites. - if (!isMemRefDereferencingOp(*op)) { - if (!allowNonDereferencingOps) + auto affMapAccInterface = dyn_cast(op); + if (!affMapAccInterface) { + if (!allowNonDereferencingOps) { // Failure: memref used in a non-dereferencing context (potentially // escapes); no replacement in these cases unless allowNonDereferencingOps // is set. 
return failure(); + } op->setOperand(memRefOperandPos, newMemRef); return success(); } // Perform index rewrites for the dereferencing op and then replace the op - NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef); + NamedAttribute oldMapAttrPair = + affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef); AffineMap oldMap = oldMapAttrPair.second.cast().getValue(); unsigned oldMapNumInputs = oldMap.getNumInputs(); SmallVector oldMapOperands( @@ -272,7 +259,7 @@ // Check if the memref was used in a non-dereferencing context. It is fine // for the memref to be used in a non-dereferencing way outside of the // region where this replacement is happening. - if (!isMemRefDereferencingOp(*op)) { + if (!isa(*op)) { if (!allowNonDereferencingOps) return failure(); // Currently we support the following non-dereferencing ops to be a