diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td @@ -68,6 +68,19 @@ return op.getAffineMapAttr().getValue(); }] >, + InterfaceMethod< + /*desc=*/"Returns the AffineMapAttr associated with 'memref'.", + /*retTy=*/"NamedAttribute", + /*methodName=*/"getAffineMapAttrForMemRef", + /*args=*/(ins "Value":$memref), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + ConcreteOp op = cast(this->getOperation()); + assert(memref == getMemRef()); + return {Identifier::get(op.getMapAttrName(), op.getContext()), + op.getAffineMapAttr()}; + }] + >, ]; } @@ -124,6 +137,19 @@ return op.getAffineMapAttr().getValue(); }] >, + InterfaceMethod< + /*desc=*/"Returns the AffineMapAttr associated with 'memref'.", + /*retTy=*/"NamedAttribute", + /*methodName=*/"getAffineMapAttrForMemRef", + /*args=*/(ins "Value":$memref), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + ConcreteOp op = cast(this->getOperation()); + assert(memref == getMemRef()); + return {Identifier::get(op.getMapAttrName(), op.getContext()), + op.getAffineMapAttr()}; + }] + >, ]; } diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -389,13 +389,6 @@ return getAttr(getMapAttrName()).cast(); } - /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(Value memref) { - assert(memref == getMemRef()); - return {Identifier::get(getMapAttrName(), getContext()), - getAffineMapAttr()}; - } - static StringRef getMapAttrName() { return "map"; } }]; } diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -38,10 +38,10 @@ static void getLoadAndStoreMemRefAccesses(Operation *opA, DenseMap &values) { opA->walk([&](Operation *op) { - if (auto loadOp = dyn_cast(op)) { + if (auto loadOp = dyn_cast(op)) { if (values.count(loadOp.getMemRef()) == 0) values[loadOp.getMemRef()] = false; - } else if (auto storeOp = dyn_cast(op)) { + } else if (auto storeOp = dyn_cast(op)) { values[storeOp.getMemRef()] = true; } }); @@ -52,10 +52,10 @@ // Returns false otherwise. static bool isDependentLoadOrStoreOp(Operation *op, DenseMap &values) { - if (auto loadOp = dyn_cast(op)) { + if (auto loadOp = dyn_cast(op)) { return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()] == true; - } else if (auto storeOp = dyn_cast(op)) { + } else if (auto storeOp = dyn_cast(op)) { return values.count(storeOp.getMemRef()) > 0; } return false; @@ -105,7 +105,7 @@ it != Block::reverse_iterator(opA); ++it) { Operation *opX = &(*it); opX->walk([&](Operation *op) { - if (isa(op) || isa(op)) { + if (isa(op) || isa(op)) { if (isDependentLoadOrStoreOp(op, values)) { lastDepOp = opX; return WalkResult::interrupt(); @@ -179,7 +179,7 @@ SmallVectorImpl &loadAndStoreOps) { bool hasIfOp = false; forOp.walk([&](Operation *op) { - if (isa(op) || isa(op)) + if (isa(op) || isa(op)) loadAndStoreOps.push_back(op); else if (isa(op)) hasIfOp = true; @@ -464,7 +464,7 @@ unsigned storeCount = 0; llvm::SmallDenseSet storeMemrefs; srcForOp.walk([&](Operation *op) { - if (auto storeOp = dyn_cast(op)) { + if (auto storeOp = dyn_cast(op)) { storeMemrefs.insert(storeOp.getMemRef()); ++storeCount; } @@ -476,7 +476,7 @@ // 'insertPointParent'. for (auto value : storeMemrefs) { for (auto *user : value.getUsers()) { - if (auto loadOp = dyn_cast(user)) { + if (auto loadOp = dyn_cast(user)) { SmallVector loops; // Check if any loop in loop nest surrounding 'user' is // 'insertPointParent'. diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -30,7 +30,7 @@ // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) static bool isMemRefDereferencingOp(Operation &op) { - if (isa(op) || isa(op) || + if (isa(op) || isa(op) || isa(op) || isa(op)) return true; return false; @@ -39,8 +39,8 @@ /// Return the AffineMapAttr associated with memory 'op' on 'memref'. static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) { return TypeSwitch(op) - .Case( + .Case( [=](auto op) { return op.getAffineMapAttrForMemRef(memref); }); } diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -2464,3 +2464,32 @@ // MAXIMAL-NEXT: affine.for // MAXIMAL-NOT: affine.for // MAXIMAL: return + +// ----- + +// CHECK-LABEL: func @vector_loop +func @vector_loop(%a : memref<10x20xf32>, %b : memref<10x20xf32>, + %c : memref<10x20xf32>) { + affine.for %j = 0 to 10 { + affine.for %i = 0 to 5 { + %ld0 = affine.vector_load %a[%j, %i*4] : memref<10x20xf32>, vector<4xf32> + affine.vector_store %ld0, %b[%j, %i*4] : memref<10x20xf32>, vector<4xf32> + } + } + + affine.for %j = 0 to 10 { + affine.for %i = 0 to 5 { + %ld0 = affine.vector_load %b[%j, %i*4] : memref<10x20xf32>, vector<4xf32> + affine.vector_store %ld0, %c[%j, %i*4] : memref<10x20xf32>, vector<4xf32> + } + } + + return +} +// CHECK: affine.for +// CHECK-NEXT: affine.for +// CHECK-NEXT: affine.vector_load +// CHECK-NEXT: affine.vector_store +// CHECK-NEXT: affine.vector_load +// CHECK-NEXT: affine.vector_store +// CHECK-NOT: affine.for