diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -2433,9 +2433,12 @@ // Get result memref type. auto memrefType = viewOp.getType(); - if (memrefType.getAffineMaps().size() != 1) + if (memrefType.getAffineMaps().size() > 1) return matchFailure(); - auto map = memrefType.getAffineMaps()[0]; + auto map = memrefType.getAffineMaps().empty() + ? AffineMap::getMultiDimIdentityMap(memrefType.getRank(), + rewriter.getContext()) + : memrefType.getAffineMaps()[0]; // Get offset from old memref view type 'memRefType'. int64_t oldOffset; diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -692,6 +692,7 @@ // CHECK-DAG: #[[VIEW_MAP2:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1 + 15)> // CHECK-DAG: #[[VIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 7 + d2)> // CHECK-DAG: #[[VIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 15)> +// CHECK-DAG: #[[VIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 7 + d1)> // CHECK-LABEL: func @view func @view(%arg0 : index) { @@ -736,7 +737,7 @@ load %5[%c0, %c0] : memref // Test: folding static alloc and memref_cast into a view. - // CHECK: std.view %[[ALLOC_MEM]][][%c15, %c7] : memref<2048xi8> to memref + // CHECK: std.view %[[ALLOC_MEM]][][] : memref<2048xi8> to memref<15x7xf32, #[[VIEW_MAP5]]> %6 = memref_cast %0 : memref<2048xi8> to memref %7 = view %6[%c15][%c7] : memref to memref load %7[%c0, %c0] : memref