[mlir][Affine] Forward stores to loads through the affine memory op interfaces

Generalize MemRefDataFlowOpt from AffineLoadOp/AffineStoreOp to
AffineReadOpInterface/AffineWriteOpInterface so that affine.vector_load and
affine.vector_store also take part in store-to-load forwarding. To support
this, AffineReadOpInterface gains getValue() and AffineWriteOpInterface gains
getValueToStore() (backed by the new getStoredValOperandIndex() on the store
ops). Forwarding is skipped when the stored value's type differs from the
loaded value's type (e.g. a vector store feeding a scalar load), which would
otherwise create type-invalid IR.

Index: mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
===================================================================
--- mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
+++ mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
@@ -81,6 +81,17 @@
                 op.getAffineMapAttr()};
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Returns the value read by this operation.",
+      /*retTy=*/"Value",
+      /*methodName=*/"getValue",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op;
+      }]
+    >,
   ];
 }
 
@@ -150,6 +161,17 @@
                 op.getAffineMapAttr()};
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Returns the value to store.",
+      /*retTy=*/"Value",
+      /*methodName=*/"getValueToStore",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getOperand(op.getStoredValOperandIndex());
+      }]
+    >,
   ];
 }
 
Index: mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
===================================================================
--- mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -725,8 +725,8 @@
     [DeclareOpInterfaceMethods<AffineWriteOpInterface>])> {
 
   code extraClassDeclarationBase = [{
-    /// Get value to be stored by store operation.
-    Value getValueToStore() { return getOperand(0); }
+    /// Returns the operand index of the stored value.
+    unsigned getStoredValOperandIndex() { return 0; }
 
     /// Returns the operand index of the memref.
     unsigned getMemRefOperandIndex() { return 1; }
Index: mlir/lib/Transforms/MemRefDataFlowOpt.cpp
===================================================================
--- mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -63,7 +63,7 @@
 struct MemRefDataFlowOpt : public MemRefDataFlowOptBase<MemRefDataFlowOpt> {
   void runOnFunction() override;
 
-  void forwardStoreToLoad(AffineLoadOp loadOp);
+  void forwardStoreToLoad(AffineReadOpInterface loadOp);
 
   // A list of memref's that are potentially dead / could be eliminated.
   SmallPtrSet<Value, 4> memrefsToErase;
@@ -84,14 +84,14 @@
 
 // This is a straightforward implementation not optimized for speed. Optimize
 // if needed.
-void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
+void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
   // First pass over the use list to get the minimum number of surrounding
   // loops common between the load op and the store op, with min taken across
   // all store ops.
   SmallVector<Operation *, 8> storeOps;
   unsigned minSurroundingLoops = getNestingDepth(loadOp);
   for (auto *user : loadOp.getMemRef().getUsers()) {
-    auto storeOp = dyn_cast<AffineStoreOp>(user);
+    auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
     if (!storeOp)
       continue;
     unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
@@ -167,8 +167,15 @@
     return;
 
   // Perform the actual store to load forwarding.
-  Value storeVal = cast<AffineStoreOp>(lastWriteStoreOp).getValueToStore();
-  loadOp.replaceAllUsesWith(storeVal);
+  Value storeVal =
+      cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
+  // The write and the read may carry values of different types (e.g. an
+  // affine.vector_store followed by a scalar affine.load, or vector ops of
+  // different widths on the same memref). Replacing the uses in that case
+  // would produce type-invalid IR, so only forward on an exact type match.
+  if (storeVal.getType() != loadOp.getValue().getType())
+    return;
+  loadOp.getValue().replaceAllUsesWith(storeVal);
   // Record the memref for a later sweep to optimize away.
   memrefsToErase.insert(loadOp.getMemRef());
   // Record this to erase later.
@@ -190,7 +197,7 @@
   memrefsToErase.clear();
 
   // Walk all load's and perform store to load forwarding.
-  f.walk([&](AffineLoadOp loadOp) { forwardStoreToLoad(loadOp); });
+  f.walk([&](AffineReadOpInterface loadOp) { forwardStoreToLoad(loadOp); });
 
   // Erase all load op's whose results were replaced with store fwd'ed ones.
   for (auto *loadOp : loadOpsToErase)
@@ -207,7 +214,7 @@
       // could still erase it if the call had no side-effects.
       continue;
     if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
-          return !isa<AffineStoreOp, DeallocOp>(ownerOp);
+          return !isa<AffineWriteOpInterface, DeallocOp>(ownerOp);
         }))
       continue;
 
Index: mlir/test/Transforms/memref-vector-dataflow.mlir
===================================================================
--- /dev/null
+++ mlir/test/Transforms/memref-vector-dataflow.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -memref-dataflow-opt -split-input-file | FileCheck %s
+
+// A vector value stored by affine.vector_store is forwarded to a later
+// affine.vector_load of the same type from the same location.
+
+// CHECK-LABEL: func @vector_forwarding
+func @vector_forwarding(%ifm0 : memref<10x20x30x512xf32>,
+                        %ifm1 : memref<10x20x30x512xf32>,
+                        %ofm : memref<10x20x30x512xf32>) {
+  affine.for %d = 0 to 16 {
+    affine.for %b = 0 to 10 {
+      affine.for %h = 0 to 20 {
+        affine.for %w = 0 to 30 {
+          %lhs = affine.vector_load %ifm0[%b, %h, %w, 32*%d] : memref<10x20x30x512xf32>, vector<32xf32>
+          %rhs = affine.vector_load %ifm1[%b, %h, %w, 32*%d] : memref<10x20x30x512xf32>, vector<32xf32>
+          %add = addf %lhs, %rhs : vector<32xf32>
+          affine.vector_store %add, %ifm0[%b, %h, %w, 32*%d] : memref<10x20x30x512xf32>, vector<32xf32>
+          %lhs1 = affine.vector_load %ifm0[%b, %h, %w, 32*%d] : memref<10x20x30x512xf32>, vector<32xf32>
+          %add1 = addf %lhs1, %rhs : vector<32xf32>
+          affine.vector_store %add1, %ofm[%b, %h, %w, 32*%d] : memref<10x20x30x512xf32>, vector<32xf32>
+        }
+      }
+    }
+  }
+  return
+}
+
+// CHECK:      affine.for %{{.*}} = 0 to 16
+// CHECK-NEXT:   affine.for %{{.*}} = 0 to 10
+// CHECK-NEXT:     affine.for %{{.*}} = 0 to 20
+// CHECK-NEXT:       affine.for %{{.*}} = 0 to 30
+// CHECK-NEXT:         affine.vector_load
+// CHECK-NEXT:         affine.vector_load
+// CHECK-NEXT:         %[[ADD:.*]] = addf
+// CHECK-NEXT:         affine.vector_store %[[ADD]]
+// CHECK-NEXT:         addf %[[ADD]]
+// CHECK-NEXT:         affine.vector_store
+// CHECK-NEXT:       }
+
+// -----
+
+// A vector value must not be forwarded to a scalar affine.load of the same
+// location: the result types differ, so the load has to stay.
+
+// CHECK-LABEL: func @no_forwarding_on_type_mismatch
+func @no_forwarding_on_type_mismatch(%in : memref<512xf32>,
+                                     %out : memref<16xf32>) {
+  %tmp = alloc() : memref<512xf32>
+  affine.for %i = 0 to 16 {
+    %v = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    affine.vector_store %v, %tmp[32*%i] : memref<512xf32>, vector<32xf32>
+    %s = affine.load %tmp[32*%i] : memref<512xf32>
+    affine.store %s, %out[%i] : memref<16xf32>
+  }
+  return
+}
+
+// CHECK:      affine.vector_store
+// CHECK-NEXT: %[[S:.*]] = affine.load
+// CHECK-NEXT: affine.store %[[S]]