diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -425,11 +425,51 @@ } }; +// Pushes the index_casts that occur before extractions to after the extract. +// This minimizes type conversion in some cases and enables the extract +// canonicalizer. This changes: +// +// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex> +// %extract = tensor.extract %cast[%index] : tensor<1xindex> +// +// to the following: +// +// %extract = tensor.extract %tensor[%index] : tensor<1xi32> +// %cast = arith.index_cast %extract : i32 to index +// +// so the cast converts only the extracted element, not the whole tensor. +// +// Consider expanding this to a template and handle all tensor cast operations. +struct ExtractElementFromIndexCast + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extract, + PatternRewriter &rewriter) const final { + Location loc = extract.getLoc(); + auto indexCast = extract.tensor().getDefiningOp(); + if (!indexCast) + return failure(); + + Type elementTy = getElementTypeOrSelf(indexCast.getIn()); + + auto newExtract = rewriter.create( + loc, elementTy, indexCast.getIn(), extract.indices()); + + rewriter.replaceOpWithNewOp(extract, extract.getType(), + newExtract); + + return success(); + } +}; + } // namespace void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results + .add( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1200,3 +1200,17 @@ %1 = tensor.expand_shape %0 [] : tensor into tensor<1xi32> return %1 : tensor<1xi32> } + +// ----- + +// CHECK-LABEL: 
func @propogate_index_cast +func @propogate_index_cast(%arg0: tensor<1xi32>) -> index { + // CHECK: %[[IDX:.+]] = arith.constant 0 + // CHECK: %[[EXT:.+]] = tensor.extract %arg0[%[[IDX]]] : tensor<1xi32> + // CHECK: %[[CAST:.+]] = arith.index_cast %[[EXT]] + // CHECK: return %[[CAST]] : index + %c0 = arith.constant 0 : index + %0 = arith.index_cast %arg0 : tensor<1xi32> to tensor<1xindex> + %1 = tensor.extract %0[%c0] : tensor<1xindex> + return %1 : index +}