diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h @@ -0,0 +1,24 @@ +//===- AffineMemoryOpInterfaces.h -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a set of interfaces for affine memory ops. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_AFFINEMEMORYOPINTERFACES_H_ +#define MLIR_INTERFACES_AFFINEMEMORYOPINTERFACES_H_ + +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h.inc" +} // namespace mlir + +#endif // MLIR_INTERFACES_AFFINEMEMORYOPINTERFACES_H_ diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td @@ -0,0 +1,128 @@ +//===- AffineMemoryOpInterfaces.td -------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a set of interfaces for affine memory ops. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_AFFINEMEMORYOPINTERFACES
+#define MLIR_AFFINEMEMORYOPINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def AffineReadOpInterface : OpInterface<"AffineReadOpInterface"> {
+  let description = [{
+    Interface to query characteristics of read-like ops with affine
+    restrictions.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{ Returns the memref operand to read from. }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getMemRef",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getOperand(op.getMemRefOperandIndex());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns the type of the memref operand. }],
+      /*retTy=*/"MemRefType",
+      /*methodName=*/"getMemRefType",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getMemRef().getType().template cast<MemRefType>();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns affine map operands. }],
+      /*retTy=*/"Operation::operand_range",
+      /*methodName=*/"getMapOperands",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return llvm::drop_begin(op.getOperands(), 1);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns the affine map used to index the memref for this
+                  operation. }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getAffineMap",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getAffineMapAttr().getValue();
+      }]
+    >,
+  ];
+}
+
+def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> {
+  let description = [{
+    Interface to query characteristics of write-like ops with affine
+    restrictions.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{ Returns the memref operand to write to. }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getMemRef",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getOperand(op.getMemRefOperandIndex());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns the type of the memref operand. }],
+      /*retTy=*/"MemRefType",
+      /*methodName=*/"getMemRefType",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getMemRef().getType().template cast<MemRefType>();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns affine map operands. }],
+      /*retTy=*/"Operation::operand_range",
+      /*methodName=*/"getMapOperands",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return llvm::drop_begin(op.getOperands(), 2);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns the affine map used to index the memref for this
+                  operation. }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getAffineMap",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getAffineMapAttr().getValue();
+      }]
+    >,
+  ];
+}
+
+#endif // MLIR_AFFINEMEMORYOPINTERFACES
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
 #define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
 
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -14,6 +14,7 @@
 #define AFFINE_OPS
 
 include
 include "mlir/Dialect/Affine/IR/AffineOpsBase.td"
+include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -371,7 +372,8 @@
 }
 
 class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
-    Affine_Op<mnemonic, traits> {
+    Affine_Op<mnemonic, !listconcat(traits,
+        [DeclareOpInterfaceMethods<AffineReadOpInterface>])> {
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$memref,
       Variadic<Index>:$indices);
@@ -380,18 +382,9 @@
     /// Returns the operand index of the memref.
     unsigned getMemRefOperandIndex() { return 0; }
 
-    /// Get memref operand.
-    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
     void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
-    MemRefType getMemRefType() {
-      return getMemRef().getType().cast<MemRefType>();
-    }
-
-    /// Get affine map operands.
-    operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
 
     /// Returns the affine map used to index the memref for this operation.
-    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
     AffineMapAttr getAffineMapAttr() {
       return getAttr(getMapAttrName()).cast<AffineMapAttr>();
     }
@@ -407,7 +400,7 @@
   }];
 }
 
-def AffineLoadOp : AffineLoadOpBase<"load", []> {
+def AffineLoadOp : AffineLoadOpBase<"load"> {
   let summary = "affine load operation";
   let description = [{
     The "affine.load" op reads an element from a memref, where the index
@@ -666,8 +659,8 @@
 }
 
 class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
-    Affine_Op<mnemonic, traits> {
-
+    Affine_Op<mnemonic, !listconcat(traits,
+        [DeclareOpInterfaceMethods<AffineWriteOpInterface>])> {
   code extraClassDeclarationBase = [{
     /// Get value to be stored by store operation.
     Value getValueToStore() { return getOperand(0); }
@@ -675,19 +668,9 @@
     /// Returns the operand index of the memref.
     unsigned getMemRefOperandIndex() { return 1; }
 
-    /// Get memref operand.
-    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
     void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
-    MemRefType getMemRefType() {
-      return getMemRef().getType().cast<MemRefType>();
-    }
-
-    /// Get affine map operands.
-    operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
-
     /// Returns the affine map used to index the memref for this operation.
-    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
     AffineMapAttr getAffineMapAttr() {
       return getAttr(getMapAttrName()).cast<AffineMapAttr>();
     }
 
@@ -703,7 +686,7 @@
   }];
 }
 
-def AffineStoreOp : AffineStoreOpBase<"store", []> {
+def AffineStoreOp : AffineStoreOpBase<"store"> {
   let summary = "affine store operation";
   let description = [{
     The "affine.store" op writes an element to a memref, where the index
@@ -776,7 +759,7 @@
   let verifier = ?;
 }
 
-def AffineVectorLoadOp : AffineLoadOpBase<"vector_load", []> {
+def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> {
   let summary = "affine vector load operation";
   let description = [{
     The "affine.vector_load" is the vector counterpart of
@@ -825,7 +808,7 @@
   }];
 }
 
-def AffineVectorStoreOp : AffineStoreOpBase<"vector_store", []> {
+def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
   let summary = "affine vector store operation";
   let description = [{
     The "affine.vector_store" is the vector counterpart of
diff --git a/mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt
@@ -1,2 +1,10 @@
 add_mlir_dialect(AffineOps affine)
 add_mlir_doc(AffineOps -gen-op-doc AffineOps Dialects/)
+
+set(LLVM_TARGET_DEFINITIONS AffineMemoryOpInterfaces.td)
+mlir_tablegen(AffineMemoryOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(AffineMemoryOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRAffineMemoryOpInterfacesIncGen)
+add_dependencies(mlir-generic-headers MLIRAffineMemoryOpInterfacesIncGen)
+
+add_dependencies(MLIRAffineOpsIncGen MLIRAffineMemoryOpInterfacesIncGen)
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
---
a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -660,10 +660,12 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // Get affine map from AffineLoad/Store. AffineMap map; - if (auto loadOp = dyn_cast(opInst)) + if (auto loadOp = dyn_cast(opInst)) { map = loadOp.getAffineMap(); - else if (auto storeOp = dyn_cast(opInst)) + } else { + auto storeOp = cast(opInst); map = storeOp.getAffineMap(); + } SmallVector operands(indices.begin(), indices.end()); fullyComposeAffineMapAndOperands(&map, &operands); map = simplifyAffineMap(map); @@ -771,9 +773,10 @@ if (srcAccess.memref != dstAccess.memref) return DependenceResult::NoDependence; - // Return 'NoDependence' if one of these accesses is not an AffineStoreOp. - if (!allowRAR && !isa(srcAccess.opInst) && - !isa(dstAccess.opInst)) + // Return 'NoDependence' if one of these accesses is not an + // AffineWriteOpInterface. + if (!allowRAR && !isa(srcAccess.opInst) && + !isa(dstAccess.opInst)) return DependenceResult::NoDependence; // Get composed access function for 'srcAccess'. @@ -857,7 +860,8 @@ // Collect all load and store ops in loop nest rooted at 'forOp'. 
SmallVector loadAndStoreOpInsts; forOp.getOperation()->walk([&](Operation *opInst) { - if (isa(opInst) || isa(opInst)) + if (isa(opInst) || + isa(opInst)) loadAndStoreOpInsts.push_back(opInst); }); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -196,8 +196,8 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, ComputationSliceState *sliceState, bool addMemRefDimBounds) { - assert((isa(op) || isa(op)) && - "affine load/store op expected"); + assert((isa(op) || isa(op)) && + "affine read/write op expected"); MemRefAccess access(op); memref = access.memref; @@ -404,9 +404,10 @@ template LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp, bool emitError) { - static_assert( - llvm::is_one_of::value, - "argument should be either a AffineLoadOp or a AffineStoreOp"); + static_assert(llvm::is_one_of::value, + "argument should be either a AffineReadOpInterface or a " + "AffineWriteOpInterface"); Operation *op = loadOrStoreOp.getOperation(); MemRefRegion region(op->getLoc()); @@ -456,10 +457,10 @@ } // Explicitly instantiate the template so that the compiler knows we need them! -template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineLoadOp loadOp, - bool emitError); -template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineStoreOp storeOp, - bool emitError); +template LogicalResult +mlir::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp, bool emitError); +template LogicalResult +mlir::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp, bool emitError); // Returns in 'positions' the Block positions of 'op' in each ancestor // Block from the Block containing operation, stopping at 'limitBlock'. 
@@ -575,8 +576,8 @@ return failure(); } - bool readReadAccesses = isa(srcAccess.opInst) && - isa(dstAccess.opInst); + bool readReadAccesses = isa(srcAccess.opInst) && + isa(dstAccess.opInst); FlatAffineConstraints dependenceConstraints; // Check dependence between 'srcAccess' and 'dstAccess'. DependenceResult result = checkMemrefAccessDependence( @@ -768,7 +769,8 @@ : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end()); llvm::SmallDenseSet sequentialLoops; - if (isa(depSourceOp) && isa(depSinkOp)) { + if (isa(depSourceOp) && + isa(depSinkOp)) { // For read-read access pairs, clear any slice bounds on sequential loops. // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'. getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0], @@ -865,7 +867,7 @@ // Constructs MemRefAccess populating it with the memref, its indices and // opinst from 'loadOrStoreOpInst'. MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { - if (auto loadOp = dyn_cast(loadOrStoreOpInst)) { + if (auto loadOp = dyn_cast(loadOrStoreOpInst)) { memref = loadOp.getMemRef(); opInst = loadOrStoreOpInst; auto loadMemrefType = loadOp.getMemRefType(); @@ -874,8 +876,9 @@ indices.push_back(index); } } else { - assert(isa(loadOrStoreOpInst) && "load/store op expected"); - auto storeOp = dyn_cast(loadOrStoreOpInst); + assert(isa(loadOrStoreOpInst) && + "Affine read/write op expected"); + auto storeOp = cast(loadOrStoreOpInst); opInst = loadOrStoreOpInst; memref = storeOp.getMemRef(); auto storeMemrefType = storeOp.getMemRefType(); @@ -890,7 +893,9 @@ return memref.getType().cast().getRank(); } -bool MemRefAccess::isStore() const { return isa(opInst); } +bool MemRefAccess::isStore() const { + return isa(opInst); +} /// Returns the nesting depth of this statement, i.e., the number of loops /// surrounding this statement. @@ -947,7 +952,8 @@ // Walk this 'affine.for' operation to gather all memory regions. 
auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult { - if (!isa(opInst) && !isa(opInst)) { + if (!isa(opInst) && + !isa(opInst)) { // Neither load nor a store op. return WalkResult::advance(); } @@ -1007,7 +1013,8 @@ // Collect all load and store ops in loop nest rooted at 'forOp'. SmallVector loadAndStoreOpInsts; auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult { - if (isa(opInst) || isa(opInst)) + if (isa(opInst) || + isa(opInst)) loadAndStoreOpInsts.push_back(opInst); else if (!isa(opInst) && !isa(opInst) && !isa(opInst) && diff --git a/mlir/lib/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp b/mlir/lib/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp @@ -0,0 +1,18 @@ +//===- AffineMemoryOpInterfaces.cpp - Loop-like operations in MLIR --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Affine Memory Op Interfaces +//===----------------------------------------------------------------------===// + +/// Include the definitions of the affine memory op interfaces. 
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp.inc" diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRAffineOps + AffineMemoryOpInterfaces.cpp AffineOps.cpp AffineValueMap.cpp @@ -6,6 +7,7 @@ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine DEPENDS + MLIRAffineMemoryOpInterfacesIncGen MLIRAffineOpsIncGen LINK_LIBS PUBLIC diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -70,7 +70,7 @@ // TODO(b/117228571) Replace when this is modeled through side-effects/op traits static bool isMemRefDereferencingOp(Operation &op) { - if (isa(op) || isa(op) || + if (isa(op) || isa(op) || isa(op) || isa(op)) return true; return false; @@ -92,9 +92,9 @@ forOps.push_back(cast(op)); else if (op->getNumRegions() != 0) hasNonForRegion = true; - else if (isa(op)) + else if (isa(op)) loadOpInsts.push_back(op); - else if (isa(op)) + else if (isa(op)) storeOpInsts.push_back(op); }); } @@ -125,7 +125,7 @@ unsigned getLoadOpCount(Value memref) { unsigned loadOpCount = 0; for (auto *loadOpInst : loads) { - if (memref == cast(loadOpInst).getMemRef()) + if (memref == cast(loadOpInst).getMemRef()) ++loadOpCount; } return loadOpCount; @@ -135,7 +135,7 @@ unsigned getStoreOpCount(Value memref) { unsigned storeOpCount = 0; for (auto *storeOpInst : stores) { - if (memref == cast(storeOpInst).getMemRef()) + if (memref == cast(storeOpInst).getMemRef()) ++storeOpCount; } return storeOpCount; @@ -145,7 +145,7 @@ void getStoreOpsForMemref(Value memref, SmallVectorImpl *storeOps) { for (auto *storeOpInst : stores) { - if (memref == cast(storeOpInst).getMemRef()) + if (memref == cast(storeOpInst).getMemRef()) storeOps->push_back(storeOpInst); } } @@ -154,7 +154,7 @@ void 
getLoadOpsForMemref(Value memref, SmallVectorImpl *loadOps) { for (auto *loadOpInst : loads) { - if (memref == cast(loadOpInst).getMemRef()) + if (memref == cast(loadOpInst).getMemRef()) loadOps->push_back(loadOpInst); } } @@ -164,10 +164,10 @@ void getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) { llvm::SmallDenseSet loadMemrefs; for (auto *loadOpInst : loads) { - loadMemrefs.insert(cast(loadOpInst).getMemRef()); + loadMemrefs.insert(cast(loadOpInst).getMemRef()); } for (auto *storeOpInst : stores) { - auto memref = cast(storeOpInst).getMemRef(); + auto memref = cast(storeOpInst).getMemRef(); if (loadMemrefs.count(memref) > 0) loadAndStoreMemrefSet->insert(memref); } @@ -259,7 +259,7 @@ bool writesToLiveInOrEscapingMemrefs(unsigned id) { Node *node = getNode(id); for (auto *storeOpInst : node->stores) { - auto memref = cast(storeOpInst).getMemRef(); + auto memref = cast(storeOpInst).getMemRef(); auto *op = memref.getDefiningOp(); // Return true if 'memref' is a block argument. if (!op) @@ -272,13 +272,14 @@ return false; } - // Returns the unique AffineStoreOp in `node` that meets all the following: + // Returns the unique AffineWriteOpInterface in `node` that meets all the + // following: // *) store is the only one that writes to a function-local memref live out // of `node`, // *) store is not the source of a self-dependence on `node`. - // Otherwise, returns a null AffineStoreOp. - AffineStoreOp getUniqueOutgoingStore(Node *node) { - AffineStoreOp uniqueStore; + // Otherwise, returns a null AffineWriteOpInterface. + AffineWriteOpInterface getUniqueOutgoingStore(Node *node) { + AffineWriteOpInterface uniqueStore; // Return null if `node` doesn't have any outgoing edges. 
auto outEdgeIt = outEdges.find(node->id); @@ -287,7 +288,7 @@ const auto &nodeOutEdges = outEdgeIt->second; for (auto *op : node->stores) { - auto storeOp = cast(op); + auto storeOp = cast(op); auto memref = storeOp.getMemRef(); // Skip this store if there are no dependences on its memref. This means // that store either: @@ -322,7 +323,8 @@ Node *node = getNode(id); for (auto *storeOpInst : node->stores) { // Return false if there exist out edges from 'id' on 'memref'. - if (getOutEdgeCount(id, cast(storeOpInst).getMemRef()) > 0) + auto storeMemref = cast(storeOpInst).getMemRef(); + if (getOutEdgeCount(id, storeMemref) > 0) return false; } return true; @@ -651,28 +653,28 @@ Node node(nextNodeId++, &op); for (auto *opInst : collector.loadOpInsts) { node.loads.push_back(opInst); - auto memref = cast(opInst).getMemRef(); + auto memref = cast(opInst).getMemRef(); memrefAccesses[memref].insert(node.id); } for (auto *opInst : collector.storeOpInsts) { node.stores.push_back(opInst); - auto memref = cast(opInst).getMemRef(); + auto memref = cast(opInst).getMemRef(); memrefAccesses[memref].insert(node.id); } forToNodeMap[&op] = node.id; nodes.insert({node.id, node}); - } else if (auto loadOp = dyn_cast(op)) { + } else if (auto loadOp = dyn_cast(op)) { // Create graph node for top-level load op. Node node(nextNodeId++, &op); node.loads.push_back(&op); - auto memref = cast(op).getMemRef(); + auto memref = cast(op).getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (auto storeOp = dyn_cast(op)) { + } else if (auto storeOp = dyn_cast(op)) { // Create graph node for top-level store op. 
Node node(nextNodeId++, &op); node.stores.push_back(&op); - auto memref = cast(op).getMemRef(); + auto memref = cast(op).getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } else if (op.getNumRegions() != 0) { @@ -733,7 +735,7 @@ dstLoads->clear(); SmallVector srcLoadsToKeep; for (auto *load : *srcLoads) { - if (cast(load).getMemRef() == memref) + if (cast(load).getMemRef() == memref) dstLoads->push_back(load); else srcLoadsToKeep.push_back(load); @@ -854,7 +856,7 @@ // Builder to create constants at the top level. OpBuilder top(forInst->getParentOfType().getBody()); // Create new memref type based on slice bounds. - auto oldMemRef = cast(srcStoreOpInst).getMemRef(); + auto oldMemRef = cast(srcStoreOpInst).getMemRef(); auto oldMemRefType = oldMemRef.getType().cast(); unsigned rank = oldMemRefType.getRank(); @@ -962,9 +964,10 @@ // Returns true if 'dstNode's read/write region to 'memref' is a super set of // 'srcNode's write region to 'memref' and 'srcId' has only one output edge. // TODO(andydavis) Generalize this to handle more live in/out cases. -static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, - AffineStoreOp srcLiveOutStoreOp, - MemRefDependenceGraph *mdg) { +static bool +canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, + AffineWriteOpInterface srcLiveOutStoreOp, + MemRefDependenceGraph *mdg) { assert(srcLiveOutStoreOp && "Expected a valid store op"); auto *dstNode = mdg->getNode(dstId); Value memref = srcLiveOutStoreOp.getMemRef(); @@ -1450,7 +1453,7 @@ DenseSet visitedMemrefs; while (!loads.empty()) { // Get memref of load on top of the stack. - auto memref = cast(loads.back()).getMemRef(); + auto memref = cast(loads.back()).getMemRef(); if (visitedMemrefs.count(memref) > 0) continue; visitedMemrefs.insert(memref); @@ -1488,7 +1491,7 @@ // feasibility for loops with multiple stores. 
unsigned maxLoopDepth = 0; for (auto *op : srcNode->stores) { - auto storeOp = cast(op); + auto storeOp = cast(op); if (storeOp.getMemRef() != memref) { srcStoreOp = nullptr; break; @@ -1563,7 +1566,7 @@ // Gather 'dstNode' store ops to 'memref'. SmallVector dstStoreOpInsts; for (auto *storeOpInst : dstNode->stores) - if (cast(storeOpInst).getMemRef() == memref) + if (cast(storeOpInst).getMemRef() == memref) dstStoreOpInsts.push_back(storeOpInst); unsigned bestDstLoopDepth; @@ -1601,7 +1604,8 @@ // Create private memref for 'memref' in 'dstAffineForOp'. SmallVector storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { - if (cast(storeOpInst).getMemRef() == memref) + if (cast(storeOpInst).getMemRef() == + memref) storesForMemref.push_back(storeOpInst); } // TODO(andydavis) Use union of memref write regions to compute @@ -1624,7 +1628,8 @@ // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { - auto loadMemRef = cast(loadOpInst).getMemRef(); + auto loadMemRef = + cast(loadOpInst).getMemRef(); // NOTE: Change 'loads' to a hash set in case efficiency is an // issue. We still use a vector since it's expected to be small. if (visitedMemrefs.count(loadMemRef) == 0 && @@ -1785,7 +1790,8 @@ // Check that all stores are to the same memref. DenseSet storeMemrefs; for (auto *storeOpInst : sibNode->stores) { - storeMemrefs.insert(cast(storeOpInst).getMemRef()); + storeMemrefs.insert( + cast(storeOpInst).getMemRef()); } if (storeMemrefs.size() != 1) return false; @@ -1796,7 +1802,7 @@ auto fn = dstNode->op->getParentOfType(); for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) { for (auto *user : fn.getArgument(i).getUsers()) { - if (auto loadOp = dyn_cast(user)) { + if (auto loadOp = dyn_cast(user)) { // Gather loops surrounding 'use'. 
         SmallVector<AffineForOp, 4> loops;
         getLoopIVs(*user, &loops);
diff --git a/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp b/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp
--- a/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp
+++ b/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp
@@ -37,8 +37,9 @@
 void TestMemRefBoundCheck::runOnFunction() {
   getFunction().walk([](Operation *opInst) {
-    TypeSwitch<Operation *>(opInst).Case<AffineLoadOp, AffineStoreOp>(
-        [](auto op) { boundCheckLoadOrStoreOp(op); });
+    TypeSwitch<Operation *>(opInst)
+        .Case<AffineReadOpInterface, AffineWriteOpInterface>(
+            [](auto op) { boundCheckLoadOrStoreOp(op); });
     // TODO(bondhugula): do this for DMA ops as well.
   });