diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -425,11 +425,49 @@
   }
 };
 
+// Pushes the index_casts that occur before extractions to after the extract,
+// so the cast is applied to the extracted element rather than to the whole
+// tensor. This minimizes type conversion in some cases and enables the
+// extract canonicalizer. This changes:
+//
+//   %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
+//   %extract = tensor.extract %cast[%index] : tensor<1xindex>
+//
+// to the following:
+//
+//   %extract = tensor.extract %tensor[%index] : tensor<1xi32>
+//   %cast = arith.index_cast %extract : i32 to index
+struct ExtractElementFromIndexCast
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+                                PatternRewriter &rewriter) const final {
+    // Only match extracts whose source tensor is produced by an index_cast.
+    auto indexCast = extract.tensor().getDefiningOp<arith::IndexCastOp>();
+    if (!indexCast)
+      return failure();
+
+    // Extract from the cast's input, yielding an element of the pre-cast type.
+    Location loc = extract.getLoc();
+    Type elementTy = getElementTypeOrSelf(indexCast.getIn());
+    auto newExtract = rewriter.create<tensor::ExtractOp>(
+        loc, elementTy, indexCast.getIn(), extract.indices());
+
+    // Cast the extracted element to the type the original extract produced.
+    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
+                                                    newExtract);
+    return success();
+  }
+};
+
 } // namespace
 
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add<ExtractElementFromTensorFromElements>(context);
+  results
+      .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//