diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1861,8 +1861,12 @@ AffineMap candidateLayout; if (candidateReduced.getAffineMaps().empty()) candidateLayout = getStridedLinearLayoutMap(candidateReduced); - else - candidateLayout = candidateReduced.getAffineMaps().front(); + else { + auto affineMaps = candidateReduced.getAffineMaps(); + candidateLayout = affineMaps.back(); + for (size_t i = affineMaps.size() - 1; i > 0; --i) + candidateLayout = candidateLayout.compose(affineMaps[i - 1]); + } assert(inferredType.getNumResults() == 1 && candidateLayout.getNumResults() == 1); if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() || diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -211,3 +211,15 @@ // CHECK: func @collapse_shape_to_dynamic // CHECK: memref.collapse_shape // CHECK-SAME: [0], [1], [2, 3, 4] + +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> +#map2 = affine_map<(d0, d1, d2, d3) -> (64 * d0 + 16 * d1 + 2 * d2 + d3)> +#map3 = affine_map<(d0, d1, d2, d3) -> (64 * d0 + 16 * d1 + 2 * d2 + d3 + 39)> + +// CHECK-LABEL: func @subview_with_affine_maps +func @subview_with_affine_maps(%arg0: memref<1x2x4x8xf32, #map1, #map2>) + -> memref<1x1x2x3xf32, #map1, #map3> { + %0 = memref.subview %arg0 [0, 1, 2, 3][1, 1, 2, 3][1, 1, 1, 1] : + memref<1x2x4x8xf32, #map1, #map2> to memref<1x1x2x3xf32, #map1, #map3> + return %0 : memref<1x1x2x3xf32, #map1, #map3> +}