diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1758,11 +1758,12 @@
 
 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                           memref::CopyOp copyOp) {
-
   auto srcType = cast<MemRefType>(copyOp.getSource().getType());
   auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
     return failure();
+  if (srcType.getNumElements() == 0 || dstType.getNumElements() == 0)
+    return failure();
 
   auto srcElementType = getElementTypeOrSelf(srcType);
   auto dstElementType = getElementTypeOrSelf(dstType);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -879,16 +879,37 @@
     return success();
   }
 };
+
+/// Fold memref.copy ops that copy empty buffers.
+struct FoldSizeZero : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcType = dyn_cast<MemRefType>(copyOp.getSource().getType());
+    auto dstType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
+    bool zeroCopy = false;
+    if (srcType)
+      zeroCopy = llvm::find(srcType.getShape(), 0) != srcType.getShape().end();
+    if (dstType)
+      zeroCopy |= llvm::find(dstType.getShape(), 0) != dstType.getShape().end();
+    if (zeroCopy) {
+      rewriter.eraseOp(copyOp);
+      return success();
+    }
+    return failure();
+  }
+};
 } // namespace
 
 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
+  results.add<FoldCopyOfCast, FoldSelfCopy, FoldSizeZero>(context);
 }
 
 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
-  /// copy(memrefcast) -> copy
+  // copy(memrefcast) -> copy
   bool folded = false;
   Operation *op = *this;
   for (OpOperand &operand : op->getOpOperands()) {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -173,9 +173,13 @@
   for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
     int64_t offset = mixedOffsets[i];
    int64_t size = mixedSizes[i];
-    if (!type.isDynamicDim(i) && !ShapedType::isDynamic(size))
+    if (!type.isDynamicDim(i) && !ShapedType::isDynamic(size)) {
       if (!ShapedType::isDynamic(offset) && offset + size > type.getDimSize(i))
         return op->emitOpError("dimension #") << i << " runs out of bounds";
+      // Taking a non-zero slice from a dimension of size zero is invalid.
+      if (size > 0 && type.getDimSize(i) == 0)
+        return op->emitOpError("offset #") << i << " runs out of bounds";
+    }
   }
 
   // No negative offsets.
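
Illustrative example (not from the patch's tests; operand names are
hypothetical): the check added above rejects any slice with a static size > 0
along a dimension of static extent 0, even when the offset is dynamic:

  // Invalid: dimension #2 of the destination has extent 0, so a slice of
  // static size 1 can never be in bounds, whatever %pos is at runtime.
  %0 = tensor.insert_slice %src into %dst[0, 0, %pos][4, 4, 1][1, 1, 1]
      : tensor<4x4xf32> into tensor<8x16x0xf32>
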
@@ -872,6 +876,9 @@
   auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
   if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
     return emitOpError("incorrect number of indices for extract_element");
+  for (const auto &it : llvm::enumerate(tensorType.getShape()))
+    if (it.value() == 0)
+      return emitOpError("index #") << it.index() << " is out of bounds";
   return success();
 }
 
@@ -1099,6 +1106,9 @@
   auto destType = llvm::cast<RankedTensorType>(getDest().getType());
   if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
     return emitOpError("incorrect number of indices");
+  for (const auto &it : llvm::enumerate(destType.getShape()))
+    if (it.value() == 0)
+      return emitOpError("index #") << it.index() << " is out of bounds";
   return success();
 }
 
@@ -2304,6 +2314,9 @@
       getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->getSource();
+  if (llvm::find(getSourceType().getShape(), 0) !=
+      getSourceType().getShape().end())
+    return this->getDest();
   if (succeeded(foldInsertAfterInsertSlice(*this)))
     return getResult();
   if (auto result = foldInsertAfterExtractSlice(*this))
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -178,3 +178,103 @@
   // expected-error @below {{op requires isolated-from-above targets}}
   %2 = transform.structured.vectorize %0 : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+// Linalg ops with zero dimensions do not vectorize.
+
+// CHECK-LABEL: @vectorize_matmul_zero
+func.func @vectorize_matmul_zero(%arg0: tensor<0x12xf32>,
+                                 %arg1: tensor<12x25xf32>,
+                                 %arg2: tensor<0x25xf32>) -> tensor<0x25xf32> {
+  // CHECK: linalg.matmul
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<0x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<0x25xf32>) -> tensor<0x25xf32>
+  func.return %0 : tensor<0x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Linalg ops with zero dimensions do not vectorize.
+
+// CHECK-LABEL: @vectorize_copy_memref_zero
+func.func @vectorize_copy_memref_zero(%arg0: memref<100x0xf32>,
+                                      %arg1: memref<100x0xf32>) {
+  // CHECK: linalg.copy
+  linalg.copy ins(%arg0 : memref<100x0xf32>) outs(%arg1 : memref<100x0xf32>)
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Memref copies with zero dimensions do not vectorize.
+
+// CHECK-LABEL: @vectorize_copy_memref_zero
+func.func @vectorize_copy_memref_zero(%arg0: memref<100x0xf32>,
+                                      %arg1: memref<100x0xf32>) {
+  // CHECK: memref.copy
+  memref.copy %arg0, %arg1 : memref<100x0xf32> to memref<100x0xf32>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// CHECK-LABEL: @vectorize_tensor_pad_zero
+//  CHECK-NEXT:   return {{.*}} : tensor<0xf32>
+func.func @vectorize_tensor_pad_zero(%arg0: tensor<0xf32>) -> tensor<0xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.pad %arg0 low[0] high[0] {
+  ^bb0(%p: index):
+    tensor.yield %cst : f32
+  } : tensor<0xf32> to tensor<0xf32>
+  func.return %0 : tensor<0xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// CHECK-LABEL: @vectorize_tensor_pad_zero_2(
+//  CHECK-SAME:     %[[arg1:.*]]: tensor<16xf32>,
+//  CHECK-NEXT:   return %[[arg1]] : tensor<16xf32>
+func.func @vectorize_tensor_pad_zero_2(%arg1: tensor<16xf32>, %arg0: tensor<0xf32>) -> tensor<16xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.pad %arg0 low[0] high[0] {
+  ^bb0(%p: index):
+    tensor.yield %cst : f32
+  } : tensor<0xf32> to tensor<0xf32>
+  %1 = tensor.insert_slice %0 into %arg1[0][0][1] : tensor<0xf32> into tensor<16xf32>
+  func.return %1 : tensor<16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op
+}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -588,6 +588,16 @@
 
 // -----
 
+func.func @zero_copy(%m1: memref<?x0xf32>, %m2: memref<?x0xf32>) {
+  memref.copy %m1, %m2 : memref<?x0xf32> to memref<?x0xf32>
+  return
+}
+
+// CHECK-LABEL: func @zero_copy
+//  CHECK-NEXT:   return
+
+// -----
+
 func.func @scopeMerge() {
   memref.alloca_scope {
     %cnt = "test.count"() : () -> index
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -24,6 +24,14 @@
 
 // -----
 
+func.func @extract_zero(%arg0: tensor<0xf32>, %arg1: index) {
+  // expected-error@+1 {{index #0 is out of bounds}}
+  %0 = tensor.extract %arg0[%arg1] : tensor<0xf32>
+  return
+}
+
+// -----
+
 func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
   // expected-error@+1 {{incorrect number of indices}}
   %0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
 
@@ -32,6 +40,14 @@
 
 // -----
 
+func.func @insert_zero(%arg0: tensor<0xf32>, %arg1: f32, %arg2: index) {
+  // expected-error@+1 {{index #0 is out of bounds}}
+  %0 = tensor.insert %arg1 into %arg0[%arg2] : tensor<0xf32>
+  return
+}
+
+// -----
+
 func.func @tensor.from_elements_wrong_result_type() {
   // expected-error@+2 {{'tensor.from_elements' invalid kind of type specified}}
   %c0 = arith.constant 0 : i32
 
@@ -285,6 +301,16 @@
 
 // -----
 
+func.func @insert_slice_zero(%t1: tensor<4x4xf32>, %t2: tensor<8x16x0xf32>, %pos: index) {
+  // expected-error @+1 {{offset #2 runs out of bounds}}
+  %0 = tensor.insert_slice %t1 into %t2[0, 1, %pos][4, 4, 1][1, 1, 1]
+    : tensor<4x4xf32> into tensor<8x16x0xf32>
+
+  return
+}
+
+// -----
+
 func.func @insert_slice_negative(%t1: tensor<4x4xf32>, %t2: tensor<8x16x2xf32>) {
   // expected-error @+1 {{offset #2 is negative}}
   %0 = tensor.insert_slice %t1 into %t2[0, 0, -1][4, 4, 1][1, 1, 1]
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -114,7 +114,7 @@
 // -----
 
 // CHECK-LABEL: func @slice({{.*}}) {
-func.func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
+func.func @slice(%t: tensor<8x16x4xf32>, %idx: index, %t2: tensor<0xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
 
@@ -133,6 +133,10 @@
   %3 = tensor.extract_slice %t[0, 2, 0][4, 1, 4][1, 1, 1]
     : tensor<8x16x4xf32> to tensor<4x4xf32>
 
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: tensor<0xf32> to tensor<0xf32>
+  %4 = tensor.extract_slice %t2[0][0][1] : tensor<0xf32> to tensor<0xf32>
+
   return
 }
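
Illustrative end-to-end example (a sketch of the patch's effect, not part of
the patch; the function and value names are hypothetical): with FoldSizeZero
registered and the new InsertSliceOp::fold in place, `mlir-opt -canonicalize`
reduces the body below to a bare `return %out : tensor<16xf32>`:

  func.func @demo(%src: memref<4x0xf32>, %dst: memref<4x0xf32>,
                  %t: tensor<0xf32>, %out: tensor<16xf32>) -> tensor<16xf32> {
    // Erased by the FoldSizeZero canonicalization: both buffers are empty.
    memref.copy %src, %dst : memref<4x0xf32> to memref<4x0xf32>
    // Folds to %out: inserting a zero-sized slice changes nothing.
    %0 = tensor.insert_slice %t into %out[0][0][1]
        : tensor<0xf32> into tensor<16xf32>
    return %0 : tensor<16xf32>
  }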