Index: mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td =================================================================== --- mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td +++ mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td @@ -68,19 +68,6 @@ return op.getAffineMapAttr().getValue(); }] >, - InterfaceMethod< - /*desc=*/"Returns the AffineMapAttr associated with 'memref'.", - /*retTy=*/"NamedAttribute", - /*methodName=*/"getAffineMapAttrForMemRef", - /*args=*/(ins "Value":$memref), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ - ConcreteOp op = cast(this->getOperation()); - assert(memref == getMemRef()); - return {Identifier::get(op.getMapAttrName(), op.getContext()), - op.getAffineMapAttr()}; - }] - >, InterfaceMethod< /*desc=*/"Returns the value read by this operation.", /*retTy=*/"Value", @@ -148,27 +135,38 @@ }] >, InterfaceMethod< - /*desc=*/"Returns the AffineMapAttr associated with 'memref'.", - /*retTy=*/"NamedAttribute", - /*methodName=*/"getAffineMapAttrForMemRef", - /*args=*/(ins "Value":$memref), + /*desc=*/"Returns the value to store.", + /*retTy=*/"Value", + /*methodName=*/"getValueToStore", + /*args=*/(ins), /*methodBody=*/[{}], /*defaultImplementation=*/[{ ConcreteOp op = cast(this->getOperation()); - assert(memref == getMemRef()); - return {Identifier::get(op.getMapAttrName(), op.getContext()), - op.getAffineMapAttr()}; + return op.getOperand(op.getStoredValOperandIndex()); }] >, + ]; +} + +def AffineMapAccessInterface : OpInterface<"AffineMapAccessInterface"> { + let description = [{ + Interface to query the AffineMap associated with a given MemRef. + The default implementation assumes an operation with a single MemRef, + but the interface can be extended to support multiple MemRefs as well.
+ }]; + + let methods = [ InterfaceMethod< - /*desc=*/"Returns the value to store.", - /*retTy=*/"Value", - /*methodName=*/"getValueToStore", - /*args=*/(ins), + /*desc=*/"Returns the AffineMapAttr associated with 'memref'.", + /*retTy=*/"NamedAttribute", + /*methodName=*/"getAffineMapAttrForMemRef", + /*args=*/(ins "Value":$memref), /*methodBody=*/[{}], /*defaultImplementation=*/[{ ConcreteOp op = cast(this->getOperation()); - return op.getOperand(op.getStoredValOperandIndex()); + assert(memref == op.getMemRef()); + return {Identifier::get(op.getMapAttrName(), op.getContext()), + op.getAffineMapAttr()}; }] >, ]; Index: mlir/include/mlir/Dialect/Affine/IR/AffineOps.h =================================================================== --- mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -82,7 +82,8 @@ // TODO: Consider replacing src/dst memref indices with view memrefs. class AffineDmaStartOp : public Op { + OpTrait::VariadicOperands, OpTrait::ZeroResult, + AffineMapAccessInterface::Trait> { public: using Op::Op; @@ -191,6 +192,7 @@ getTagMap().getNumInputs()); } + /// Implements the AffineMapAccessInterface /// Returns the AffineMapAttr associated with 'memref'. NamedAttribute getAffineMapAttrForMemRef(Value memref) { if (memref == getSrcMemRef()) @@ -271,7 +273,8 @@ // class AffineDmaWaitOp : public Op { + OpTrait::VariadicOperands, OpTrait::ZeroResult, + AffineMapAccessInterface::Trait> { public: using Op::Op; @@ -303,6 +306,7 @@ return getTagMemRef().getType().cast().getRank(); } + /// Implements the AffineMapAccessInterface /// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value memref) { assert(memref == getTagMemRef()); Index: mlir/include/mlir/Dialect/Affine/IR/AffineOps.td =================================================================== --- mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -458,6 +458,7 @@ class AffineLoadOpBase traits = []> : Affine_Op, + DeclareOpInterfaceMethods, MemRefsNormalizable])> { let arguments = (ins Arg:$memref, @@ -698,7 +699,8 @@ let hasFolder = 1; } -def AffinePrefetchOp : Affine_Op<"prefetch"> { +def AffinePrefetchOp : Affine_Op<"prefetch", + [DeclareOpInterfaceMethods]> { let summary = "affine prefetch operation"; let description = [{ The "affine.prefetch" op prefetches data from a memref location described @@ -752,6 +754,7 @@ return (*this)->getAttr(getMapAttrName()).cast(); } + /// Implements the AffineMapAccessInterface /// Returns the AffineMapAttr associated with 'memref'. NamedAttribute getAffineMapAttrForMemRef(Value mref) { assert(mref == memref()); @@ -777,6 +780,7 @@ class AffineStoreOpBase traits = []> : Affine_Op, + DeclareOpInterfaceMethods, MemRefsNormalizable])> { code extraClassDeclarationBase = [{ /// Returns the operand index of the value to be stored. @@ -792,6 +796,7 @@ return (*this)->getAttr(getMapAttrName()).cast(); } + /// Implements the AffineMapAccessInterface /// Returns the AffineMapAttr associated with 'memref'. NamedAttribute getAffineMapAttrForMemRef(Value memref) { assert(memref == getMemRef()); Index: mlir/lib/Transforms/Utils/Utils.cpp =================================================================== --- mlir/lib/Transforms/Utils/Utils.cpp +++ mlir/lib/Transforms/Utils/Utils.cpp @@ -25,22 +25,6 @@ #include "llvm/ADT/TypeSwitch.h" using namespace mlir; -/// Return true if this operation dereferences one or more memref's. -// Temporary utility: will be replaced when this is modeled through -// side-effects/op traits.
TODO -static bool isMemRefDereferencingOp(Operation &op) { - return isa(op); -} - -/// Return the AffineMapAttr associated with memory 'op' on 'memref'. -static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) { - return TypeSwitch(op) - .Case( - [=](auto op) { return op.getAffineMapAttrForMemRef(memref); }); -} - // Perform the replacement in `op`. LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, Operation *op, @@ -88,7 +72,8 @@ OpBuilder builder(op); // The following checks if op is dereferencing memref and performs the access // index rewrites. - if (!isMemRefDereferencingOp(*op)) { + auto affMapAcc = dyn_cast(op); + if (!affMapAcc) { if (!allowNonDereferencingOps) // Failure: memref used in a non-dereferencing context (potentially // escapes); no replacement in these cases unless allowNonDereferencingOps @@ -98,7 +83,8 @@ return success(); } // Perform index rewrites for the dereferencing op and then replace the op - NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef); + NamedAttribute oldMapAttrPair = + affMapAcc.getAffineMapAttrForMemRef(oldMemRef); AffineMap oldMap = oldMapAttrPair.second.cast().getValue(); unsigned oldMapNumInputs = oldMap.getNumInputs(); SmallVector oldMapOperands( @@ -272,7 +258,7 @@ // Check if the memref was used in a non-dereferencing context. It is fine // for the memref to be used in a non-dereferencing way outside of the // region where this replacement is happening. - if (!isMemRefDereferencingOp(*op)) { + if (!isa(*op)) { if (!allowNonDereferencingOps) return failure(); // Currently we support the following non-dereferencing ops to be a