diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1300,7 +1300,7 @@ static SmallVector getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, - ArrayRef reassocation, + ArrayRef reassociation, ArrayRef inStaticShape, MemRefDescriptor &inDesc, ArrayRef outStaticShape) { @@ -1308,42 +1308,84 @@ llvm::seq(0, outStaticShape.size()), [&](int64_t outDimIndex) { return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex, outStaticShape[outDimIndex], - inStaticShape, inDesc, reassocation); + inStaticShape, inDesc, reassociation); })); } static SmallVector getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, - ArrayRef reassocation, + ArrayRef reassociation, ArrayRef inStaticShape, MemRefDescriptor &inDesc, ArrayRef outStaticShape) { DenseMap outDimToInDimMap = - getExpandedDimToCollapsedDimMap(reassocation); + getExpandedDimToCollapsedDimMap(reassociation); return llvm::to_vector<4>(llvm::map_range( llvm::seq(0, outStaticShape.size()), [&](int64_t outDimIndex) { return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex, outStaticShape, inDesc, inStaticShape, - reassocation, outDimToInDimMap); + reassociation, outDimToInDimMap); })); } static SmallVector getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, - ArrayRef reassocation, + ArrayRef reassociation, ArrayRef inStaticShape, MemRefDescriptor &inDesc, ArrayRef outStaticShape) { return outStaticShape.size() < inStaticShape.size() ? 
getAsValues(b, loc, llvmIndexType, getCollapsedOutputShape(b, loc, llvmIndexType, - reassocation, inStaticShape, + reassociation, inStaticShape, inDesc, outStaticShape)) : getAsValues(b, loc, llvmIndexType, getExpandedOutputShape(b, loc, llvmIndexType, - reassocation, inStaticShape, + reassociation, inStaticShape, inDesc, outStaticShape)); } +static void fillInStridesForExpandedMemDescriptor( + OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc, + MemRefDescriptor &dstDesc, ArrayRef reassociation) { + // See comments for computeExpandedLayoutMap for details on how the strides + // are calculated. + for (auto &en : llvm::enumerate(reassociation)) { + auto currentStrideToExpand = srcDesc.stride(b, loc, en.index()); + for (auto dstIndex : llvm::reverse(en.value())) { + dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand); + Value size = dstDesc.size(b, loc, dstIndex); + currentStrideToExpand = + b.create(loc, size, currentStrideToExpand); + } + } +} + +static void fillInStridesForCollapsedMemDescriptor( + OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc, + MemRefDescriptor &dstDesc, ArrayRef reassociation) { + // See comments for computeCollapsedLayoutMap for details on how the strides + // are calculated.
+ auto srcShape = srcType.getShape(); + for (auto &en : llvm::enumerate(reassociation)) { + ArrayRef ref = llvm::makeArrayRef(en.value()); + while (srcShape[ref.back()] == 1 && ref.size() > 1) + ref = ref.drop_back(); + dstDesc.setStride(b, loc, en.index(), srcDesc.stride(b, loc, ref.back())); + } +} + +static void fillInDynamicStridesForMemDescriptor( + OpBuilder &b, Location loc, MemRefType srcType, MemRefType dstType, + MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc, + ArrayRef reassociation) { + if (srcType.getRank() > dstType.getRank()) + fillInStridesForCollapsedMemDescriptor(b, loc, srcType, srcDesc, dstDesc, + reassociation); + else + fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc, + reassociation); +} + // ReshapeOp creates a new view descriptor of the proper rank. // For now, the only conversion supported is for target MemRef with static sizes // and strides. @@ -1360,15 +1402,6 @@ MemRefType dstType = reshapeOp.getResultType(); MemRefType srcType = reshapeOp.getSrcType(); - // The condition on the layouts can be ignored when all shapes are static. 
- if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) { - if (!srcType.getLayout().isIdentity() || - !dstType.getLayout().isIdentity()) { - return rewriter.notifyMatchFailure( - reshapeOp, "only empty layout map is supported"); - } - } - int64_t offset; SmallVector strides; if (failed(getStridesAndOffset(dstType, strides, offset))) { @@ -1401,14 +1434,9 @@ for (auto &en : llvm::enumerate(strides)) dstDesc.setConstantStride(rewriter, loc, en.index(), en.value()); } else { - Value c1 = rewriter.create(loc, llvmIndexType, - rewriter.getIndexAttr(1)); - Value stride = c1; - for (auto dimIndex : - llvm::reverse(llvm::seq(0, dstShape.size()))) { - dstDesc.setStride(rewriter, loc, dimIndex, stride); - stride = rewriter.create(loc, dstShape[dimIndex], stride); - } + fillInDynamicStridesForMemDescriptor(rewriter, loc, srcType, dstType, + srcDesc, dstDesc, + reshapeOp.getReassociationIndices()); } rewriter.replaceOp(reshapeOp, {dstDesc}); return success(); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1823,11 +1823,17 @@ return failure(); // The result strides are exactly the strides of the last entry of each - // reassociation. + // reassociation. The only exception is when the last entry has size 1: the + // result stride is then that of the last entry in the group whose size is + // not 1. This is only checked statically; dynamic sizes are assumed != 1.
SmallVector resultStrides; resultStrides.reserve(reassociation.size()); - for (ReassociationIndices reassoc : reassociation) - resultStrides.push_back(srcStrides[reassoc.back()]); + for (const ReassociationIndices &reassoc : reassociation) { + ArrayRef ref = llvm::makeArrayRef(reassoc); + while (srcShape[ref.back()] == 1 && ref.size() > 1) + ref = ref.drop_back(); + resultStrides.push_back(srcStrides[ref.back()]); + } // Validate that each reassociation group is contiguous. unsigned resultStrideIndex = resultStrides.size() - 1; diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -802,11 +802,10 @@ // CHECK: llvm.mlir.constant(1 : index) : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // ----- @@ -830,14 +829,17 @@ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x 
i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 - +// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 // ----- // CHECK-LABEL: func @rank_of_unranked diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -1,11 +1,14 @@ // RUN: mlir-opt %s -tensor-bufferize | FileCheck %s -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)> -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 
+ d3 + s0)> -// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 + 1)> -// CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)> -// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> + // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> + // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)> + // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)> + // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> + // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> (d0 * 2)> + // CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0) -> (d0 + 1)> + // CHECK-DAG: #[[$MAP6:.*]] = affine_map<() -> (1)> + // CHECK-DAG: #[[$MAP7:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)> + // CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)> // CHECK-LABEL: func @dim( // CHECK-SAME: %[[TENSOR:.*]]: tensor, @@ -338,17 +341,6 @@ return %1 : tensor } -// CHECK-LABEL: func @tensor.expand_shape_of_slice2( -// CHECK-SAME: %[[t1:.*]]: tensor<1x2xf32> -func @tensor.expand_shape_of_slice2(%t1: tensor<1x2xf32>) -> tensor<1xf32> { - // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP5]]> - %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32> - // CHECK: memref.collapse_shape %{{.*}} [ - // CHECK-SAME: [0, 1]] : memref<1x1xf32, #[[$MAP5]]> into memref<1xf32> - %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32> - return %1 : tensor<1xf32> -} - // CHECK-LABEL: func @tensor.collapse_shape( // CHECK-SAME: %[[t1:.*]]: tensor<2x?x?xf32> func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor { @@ -378,9 +370,33 @@ // CHECK-LABEL: func @tensor.collapse_shape_of_slice( func @tensor.collapse_shape_of_slice(%arg0: tensor<2xi32>) -> tensor { - // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, #[[$MAP3]]> + // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to 
memref<1xi32, #[[$MAP5]]> %0 = tensor.extract_slice %arg0[1] [1] [1] : tensor<2xi32> to tensor<1xi32> - // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, #[[$MAP3]]> into memref + // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, #[[$MAP5]]> into memref %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor return %1 : tensor } + +// CHECK-LABEL: func @tensor.collapse_shape_of_slice2( +// CHECK-SAME: %[[t1:.*]]: tensor<1x2xf32> +func @tensor.collapse_shape_of_slice2(%t1: tensor<1x2xf32>) -> tensor<1xf32> { + // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP3]]> + %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32> + // CHECK: memref.collapse_shape %{{.*}} [ + // CHECK-SAME: [0, 1]] : memref<1x1xf32, #[[$MAP3]]> into memref<1xf32, #[[$MAP4]]> + %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32> + return %1 : tensor<1xf32> +} + +// CHECK-LABEL: func @tensor.collapse_shape_of_slice3( +// CHECK-SAME: %[[t1:.*]]: tensor, +// CHECK-SAME: %[[OFFSET:.*]]: index, +// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor { +func @tensor.collapse_shape_of_slice3(%arg0: tensor, %offset: index, %size: index) -> tensor { + // CHECK: memref.subview %{{.*}} [1, 1] : memref to memref + %0 = tensor.extract_slice %arg0[0, %offset] [%size, 1] [1, 1] : tensor to tensor + // CHECK: memref.collapse_shape %{{.*}} [ + // CHECK-SAME: [0, 1]] : memref into memref + %ret = tensor.collapse_shape %0 [[0, 1]] : tensor into tensor + return %ret: tensor +}