diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1945,12 +1945,37 @@ } }; +/// Canonicalizes extract_element of a tensor_cast result, i.e. IR of the form +/// +/// %val = tensor_cast %source : tensor<?xi32> to tensor<2xi32> +/// %extracted_element = extract_element %val[%c0] : tensor<2xi32> +/// +/// to an extract_element on %source (this pattern does not erase the cast): +/// +/// %extracted_element = extract_element %source[%c0] : tensor<?xi32> +struct ExtractElementFromTensorCast + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractElementOp extract, + PatternRewriter &rewriter) const final { + auto tensorCast = extract.aggregate().getDefiningOp(); + if (!tensorCast) + return failure(); + + rewriter.replaceOpWithNewOp(extract, tensorCast.source(), + extract.getIndices()); + return success(); + } +}; + } // namespace void DynamicTensorFromElementsOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); + ExtractElementFromTensorCast, StaticDynamicTensorFromElements>( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1202,3 +1202,17 @@ return %2 : tensor } + +// ----- + +// CHECK-LABEL: func @extract_element_from_tensor_cast +// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32> +func @extract_element_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 { + // CHECK-NEXT: %[[C0:.*]] = constant 0 : index + %c0 = constant 0 : index + // CHECK-NOT: tensor_cast + %casted = tensor_cast %tensor : tensor<*xf32> to tensor + // CHECK-NEXT: extract_element %[[TENSOR]][%[[C0]]] + %result = extract_element %casted[%c0] : tensor + return %result : f32 +}