diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -839,6 +840,40 @@
   }
 };
 
+/// Canonicalizes an extract from a unit-dim-reducing `tensor.extract_slice`:
+///
+///   %val = tensor.extract_slice %source[0, 0][1, 5][1, 1]
+///            : tensor<1x5xi32> to tensor<5xi32>
+///   %extracted_element = tensor.extract %val[%c2] : tensor<5xi32>
+///
+/// to
+///
+///   %extracted_element = tensor.extract %source[%c0, %c2] : tensor<1x5xi32>
+struct ExtractFromCastLikeSlice : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+                                PatternRewriter &rewriter) const final {
+    auto sliceOp = extract.getTensor().getDefiningOp<ExtractSliceOp>();
+    if (!sliceOp)
+      return failure();
+
+    RankedTensorType sourceType = sliceOp.getSourceType();
+    if (!isCastLikeExtractSliceOp(sliceOp))
+      return failure();
+
+    llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
+    SmallVector<Value> indices;
+    Value zero = rewriter.create<arith::ConstantIndexOp>(extract.getLoc(), 0);
+    unsigned extractIdx = 0;
+    for (int64_t i = 0, e = sourceType.getRank(); i < e; i++)
+      indices.push_back(droppedDims[i] ? zero
+                                       : extract.getIndices()[extractIdx++]);
+    rewriter.replaceOpWithNewOp<ExtractOp>(extract, sliceOp.getSource(),
+                                           indices);
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getAsmResultNames(
@@ -902,7 +937,7 @@
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractFromTensorCast>(context);
+  results.add<ExtractFromTensorCast, ExtractFromCastLikeSlice>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -149,6 +149,36 @@
   return %result : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_from_cast_like_slice
+//  CHECK-SAME: %[[ARG:.+]]: tensor<1x5xf32>
+func.func @extract_from_cast_like_slice(%tensor: tensor<1x5xf32>) -> f32 {
+  %c2 = arith.constant 2 : index
+  %0 = tensor.extract_slice %tensor[0, 0][1, 5][1, 1] : tensor<1x5xf32> to tensor<5xf32>
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG]][%[[C0]], %[[C2]]]
+  %1 = tensor.extract %0[%c2] : tensor<5xf32>
+  // CHECK: return %[[EXTRACTED]]
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @dont_fold_extract_from_non_cast_like_slice
+//  CHECK-SAME: %[[ARG:.+]]: tensor<1x5xf32>
+func.func @dont_fold_extract_from_non_cast_like_slice(%tensor: tensor<1x5xf32>) -> f32 {
+  %c2 = arith.constant 2 : index
+  // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG]]
+  %0 = tensor.extract_slice %tensor[0, 0][1, 3][1, 1] : tensor<1x5xf32> to tensor<3xf32>
+  // CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[SLICE]]
+  %1 = tensor.extract %0[%c2] : tensor<3xf32>
+  // CHECK: return %[[EXTRACTED]]
+  return %1 : f32
+}
+
 // -----
 
 // CHECK-LABEL: func @extract_from_tensor.from_elements
@@ -1768,7 +1798,7 @@
 
 // Chain: NC -> NCnc -> NCnc -> NC
 // CHECK: func.func @unpack_pack(
-// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
 // CHECK: return %[[T]] : tensor<128x128xf32>
 func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
   %tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>