Index: mlir/lib/Dialect/StandardOps/IR/Ops.cpp
===================================================================
--- mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3058,7 +3058,23 @@
     candidateLayout = getStridedLinearLayoutMap(candidateReduced);
   else
     candidateLayout = candidateReduced.getAffineMaps().front();
-  if (inferredType != candidateLayout) {
+  assert(inferredType.getNumResults() == 1 &&
+         candidateLayout.getNumResults() == 1);
+  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
+      inferredType.getNumDims() != candidateLayout.getNumDims()) {
+    if (errMsg) {
+      llvm::raw_string_ostream os(*errMsg);
+      os << "inferred type: " << inferredType;
+    }
+    return SubViewVerificationResult::AffineMapMismatch;
+  }
+  // Check that the difference of the affine maps simplifies to 0.
+  AffineExpr diffExpr =
+      inferredType.getResult(0) - candidateLayout.getResult(0);
+  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
+                                inferredType.getNumSymbols());
+  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
+  if (!(cst && cst.getValue() == 0)) {
     if (errMsg) {
       llvm::raw_string_ostream os(*errMsg);
       os << "inferred type: " << inferredType;
@@ -3344,11 +3360,29 @@
     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
-    Type resultType = SubViewOp::inferResultType(
-        castOp.source().getType().cast<MemRefType>(),
-        extractFromI64ArrayAttr(subViewOp.static_offsets()),
-        extractFromI64ArrayAttr(subViewOp.static_sizes()),
-        extractFromI64ArrayAttr(subViewOp.static_strides()));
+    auto resultType = SubViewOp::inferResultType(
+                          castOp.source().getType().cast<MemRefType>(),
+                          extractFromI64ArrayAttr(subViewOp.static_offsets()),
+                          extractFromI64ArrayAttr(subViewOp.static_sizes()),
+                          extractFromI64ArrayAttr(subViewOp.static_strides()))
+                          .cast<MemRefType>();
+    uint32_t rankDiff =
+        subViewOp.getSourceType().getRank() - subViewOp.getType().getRank();
+    if (rankDiff > 0) {
+      auto shape = resultType.getShape();
+      auto projectedShape = shape.drop_front(rankDiff);
+      AffineMap map;
+      auto maps = resultType.getAffineMaps();
+      if (!maps.empty() && maps.front()) {
+        auto optionalUnusedDimsMask =
+            computeRankReductionMask(shape, projectedShape);
+        llvm::SmallDenseSet<unsigned> dimsToProject =
+            optionalUnusedDimsMask.getValue();
+        map = getProjectedMap(maps.front(), dimsToProject);
+      }
+      resultType = MemRefType::get(projectedShape, resultType.getElementType(),
+                                   map, resultType.getMemorySpace());
+    }
     Value newSubView = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
Index: mlir/test/Dialect/Standard/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Standard/canonicalize.mlir
+++ mlir/test/Dialect/Standard/canonicalize.mlir
@@ -143,3 +143,17 @@
   %1 = tensor_to_memref %0 : memref<?x?x16x32xi8>
   return %1 : memref<?x?x16x32xi8>
 }
+
+// CHECK-LABEL: func @subview_of_memcast
+//  CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
+//       CHECK: %[[S:.+]] = subview %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+//       CHECK: %[[M:.+]] = memref_cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
+//       CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}>
+func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
+  %0 = memref_cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
+    memref<?x?x16x32xi8> to
+    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+}