Generalize MemRefDataFlowOpt store-to-load forwarding from AffineLoadOp /
AffineStoreOp to the AffineReadOpInterface / AffineWriteOpInterface, so that
affine.vector_load / affine.vector_store are forwarded as well.

  * AffineReadOpInterface gains getValue(); AffineWriteOpInterface gains
    getValueToStore(), implemented through a new per-op hook
    getStoredValOperandIndex().
  * MemRefDataFlowOpt.cpp uses the interfaces instead of the concrete ops.
  * memref-dataflow-opt.mlir gains a vector forwarding test.

NOTE(review): this patch was recovered from a copy in which every line break
and all indentation had been collapsed and tag-like template argument lists
had been stripped. Line structure was rebuilt to match the hunk line counts.
Template arguments and indentation are reconstructed (some, e.g.
<Operation *, 8>, <Value, 4>, DeallocOp and the Affine_Op<...> context lines,
by assumption) -- TODO confirm against the target tree; if context whitespace
differs, apply with whitespace-insensitive matching.

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
@@ -81,6 +81,16 @@
                 op.getAffineMapAttr()};
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Returns the value read by this operation.",
+      /*retTy=*/"Value",
+      /*methodName=*/"getValue",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return cast<ConcreteOp>(this->getOperation());
+      }]
+    >,
   ];
 }
 
@@ -150,6 +160,17 @@
                 op.getAffineMapAttr()};
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Returns the value to store.",
+      /*retTy=*/"Value",
+      /*methodName=*/"getValueToStore",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getOperand(op.getStoredValOperandIndex());
+      }]
+    >,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -725,8 +725,8 @@
     Affine_Op<mnemonic, !listconcat(traits,
         [DeclareOpInterfaceMethods<AffineWriteOpInterface>])> {
   code extraClassDeclarationBase = [{
-    /// Get value to be stored by store operation.
-    Value getValueToStore() { return getOperand(0); }
+    /// Returns the operand index of the value to be stored.
+    unsigned getStoredValOperandIndex() { return 0; }
 
     /// Returns the operand index of the memref.
     unsigned getMemRefOperandIndex() { return 1; }
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -63,7 +63,7 @@
 struct MemRefDataFlowOpt : public MemRefDataFlowOptBase<MemRefDataFlowOpt> {
   void runOnFunction() override;
 
-  void forwardStoreToLoad(AffineLoadOp loadOp);
+  void forwardStoreToLoad(AffineReadOpInterface loadOp);
 
   // A list of memref's that are potentially dead / could be eliminated.
   SmallPtrSet<Value, 4> memrefsToErase;
@@ -84,14 +84,14 @@
 
 // This is a straightforward implementation not optimized for speed. Optimize
 // if needed.
-void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
+void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
   // First pass over the use list to get the minimum number of surrounding
   // loops common between the load op and the store op, with min taken across
   // all store ops.
   SmallVector<Operation *, 8> storeOps;
   unsigned minSurroundingLoops = getNestingDepth(loadOp);
   for (auto *user : loadOp.getMemRef().getUsers()) {
-    auto storeOp = dyn_cast<AffineStoreOp>(user);
+    auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
     if (!storeOp)
       continue;
     unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
@@ -167,8 +167,9 @@
     return;
 
   // Perform the actual store to load forwarding.
-  Value storeVal = cast<AffineStoreOp>(lastWriteStoreOp).getValueToStore();
-  loadOp.replaceAllUsesWith(storeVal);
+  Value storeVal =
+      cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
+  loadOp.getValue().replaceAllUsesWith(storeVal);
   // Record the memref for a later sweep to optimize away.
   memrefsToErase.insert(loadOp.getMemRef());
   // Record this to erase later.
@@ -190,7 +191,7 @@
   memrefsToErase.clear();
 
   // Walk all load's and perform store to load forwarding.
-  f.walk([&](AffineLoadOp loadOp) { forwardStoreToLoad(loadOp); });
+  f.walk([&](AffineReadOpInterface loadOp) { forwardStoreToLoad(loadOp); });
 
   // Erase all load op's whose results were replaced with store fwd'ed ones.
   for (auto *loadOp : loadOpsToErase)
@@ -207,7 +208,7 @@
       // could still erase it if the call had no side-effects.
       continue;
     if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
-          return !isa<AffineStoreOp, DeallocOp>(ownerOp);
+          return !isa<AffineWriteOpInterface, DeallocOp>(ownerOp);
         }))
       continue;
 
diff --git a/mlir/test/Transforms/memref-dataflow-opt.mlir b/mlir/test/Transforms/memref-dataflow-opt.mlir
--- a/mlir/test/Transforms/memref-dataflow-opt.mlir
+++ b/mlir/test/Transforms/memref-dataflow-opt.mlir
@@ -280,3 +280,23 @@
   }
   return
 }
+
+// The test checks for value forwarding from vector stores to vector loads.
+// The value loaded from %in can directly be stored to %out by eliminating
+// store and load from %tmp.
+func @vector_forwarding(%in : memref<512xf32>, %out : memref<512xf32>) {
+  %tmp = alloc() : memref<512xf32>
+  affine.for %i = 0 to 16 {
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    affine.vector_store %ld0, %tmp[32*%i] : memref<512xf32>, vector<32xf32>
+    %ld1 = affine.vector_load %tmp[32*%i] : memref<512xf32>, vector<32xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_forwarding
+// CHECK:      affine.for %{{.*}} = 0 to 16 {
+// CHECK-NEXT:   %[[LDVAL:.*]] = affine.vector_load
+// CHECK-NEXT:   affine.vector_store %[[LDVAL]],{{.*}}
+// CHECK-NEXT: }