diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -10,13 +10,14 @@
 // loading/storing from/to the original memref.
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/SmallBitVector.h"
@@ -212,9 +213,20 @@
   if (!subViewOp)
     return failure();
 
+  ValueRange indices = loadOp.indices();
+  // For affine ops, we need to apply the map to get the operands to get the
+  // "actual" indices.
+  if (auto affineLoadOp = dyn_cast<AffineLoadOp>(loadOp.getOperation())) {
+    auto expandedIndices =
+        expandAffineMap(rewriter, loadOp.getLoc(), affineLoadOp.getAffineMap(),
+                        affineLoadOp.indices());
+    if (!expandedIndices)
+      return failure();
+    indices = expandedIndices.getValue();
+  }
   SmallVector<Value, 4> sourceIndices;
-  if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
-                                  loadOp.indices(), sourceIndices)))
+  if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, indices,
+                                  sourceIndices)))
     return failure();
 
   replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
@@ -230,9 +242,20 @@
   if (!subViewOp)
     return failure();
 
+  ValueRange indices = storeOp.indices();
+  // For affine ops, we need to apply the map to get the operands to get the
+  // "actual" indices.
+  if (auto affineStoreOp = dyn_cast<AffineStoreOp>(storeOp.getOperation())) {
+    auto expandedIndices =
+        expandAffineMap(rewriter, storeOp.getLoc(),
+                        affineStoreOp.getAffineMap(), affineStoreOp.indices());
+    if (!expandedIndices)
+      return failure();
+    indices = expandedIndices.getValue();
+  }
   SmallVector<Value, 4> sourceIndices;
   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
-                                  storeOp.indices(), sourceIndices)))
+                                  indices, sourceIndices)))
     return failure();
 
   replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
diff --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
--- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
@@ -262,6 +262,7 @@
 func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
   %1 = affine.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
+  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
   // CHECK-NEXT: affine.apply
   // CHECK-NEXT: affine.apply
   // CHECK-NEXT: affine.load
@@ -269,6 +270,11 @@
   // CHECK-NEXT: affine.apply
   // CHECK-NEXT: affine.apply
   // CHECK-NEXT: affine.store
+  // Fewer operands than the memref rank.
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.apply #{{.*}}(%[[C0]])[%{{.*}}]
+  // CHECK-NEXT: affine.store
+  affine.store %1, %0[%arg3, 0] : memref<4x4xf32, offset:?, strides: [64, 3]>
   // CHECK-NEXT: return
   return %1 : f32
 }