diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -361,16 +361,13 @@
 }
 
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
-  // The tensor operand must be a known constant.
-  Attribute tensor = operands.front();
-  if (!tensor)
-    return {};
   // If this is a splat elements attribute, simply return the value. All of the
   // elements of a splat attribute are the same.
-  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
-    return splatTensor.getSplatValue<Attribute>();
+  if (Attribute tensor = operands.front())
+    if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
+      return splatTensor.getSplatValue<Attribute>();
 
-  // Otherwise, collect the constant indices into the tensor.
+  // Collect the constant indices into the tensor.
   SmallVector<uint64_t, 8> indices;
   for (Attribute indice : llvm::drop_begin(operands, 1)) {
     if (!indice || !indice.isa<IntegerAttr>())
@@ -378,10 +375,32 @@
     indices.push_back(indice.cast<IntegerAttr>().getInt());
   }
 
+  // Fold extract(from_elements(...)).
+  if (auto fromElementsOp = this->tensor().getDefiningOp<FromElementsOp>()) {
+    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
+    auto rank = tensorType.getRank();
+    assert(indices.size() == tensorType.getRank() && "rank mismatch");
+    int flatIndex = 0;
+    int stride = 1;
+    for (int i = rank - 1; i >= 0; --i) {
+      if (i < rank - 1)
+        stride *= tensorType.getDimSize(i);
+      flatIndex += indices[i] * stride;
+    }
+    // Prevent out of bounds accesses. This can happen in invalid code that will
+    // never execute.
+    if (fromElementsOp.elements().size() <= flatIndex || flatIndex < 0)
+      return {};
+    return fromElementsOp.elements()[flatIndex];
+  }
+
   // If this is an elements attribute, query the value at the given indices.
-  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
-  if (elementsAttr && elementsAttr.isValidIndex(indices))
-    return elementsAttr.getValues<Attribute>()[indices];
+  if (Attribute tensor = operands.front()) {
+    auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
+    if (elementsAttr && elementsAttr.isValidIndex(indices))
+      return elementsAttr.getValues<Attribute>()[indices];
+  }
+
   return {};
 }
 
@@ -411,47 +430,6 @@
 
 namespace {
 
-// Canonicalizes the pattern of the form
-//
-// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
-// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
-//
-// to just %element.
-struct ExtractElementFromTensorFromElements
-    : public OpRewritePattern<tensor::ExtractOp> {
-  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
-                                PatternRewriter &rewriter) const final {
-    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
-    if (!tensorFromElements)
-      return failure();
-    auto tensorType = tensorFromElements.getType().cast<RankedTensorType>();
-    auto rank = tensorType.getRank();
-    if (rank == 0) {
-      rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
-      return success();
-    }
-    SmallVector<APInt, 3> indices(rank);
-    int64_t flatIndex = 0;
-    int64_t stride = 1;
-    for (int i = rank - 1; i >= 0; --i) {
-      APInt index;
-      if (!matchPattern(extract.indices()[i], m_ConstantInt(&index)))
-        return failure();
-      if (i < rank - 1)
-        stride *= tensorType.getDimSize(i);
-      flatIndex += index.getSExtValue() * stride;
-    }
-    // Prevent out of bounds accesses. This can happen in invalid code that will
-    // never execute.
-    if (tensorFromElements->getNumOperands() <= flatIndex || flatIndex < 0)
-      return failure();
-    rewriter.replaceOp(extract, tensorFromElements.getOperand(flatIndex));
-    return success();
-  }
-};
-
 // Pushes the index_casts that occur before extractions to after the extract.
 // This minimizes type conversion in some cases and enables the extract
 // canonicalizer. This changes:
@@ -494,9 +472,7 @@
 
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results
-      .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
-          context);
+  results.add<ExtractElementFromIndexCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
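
Note (not part of the patch): with this change, extracting a constant index from a tensor.from_elements result is handled by ExtractOp::fold rather than by the removed ExtractElementFromTensorFromElements canonicalization pattern. The sketch below is a minimal, standalone illustration of row-major index flattening, the mapping such a fold relies on since tensor.from_elements lists its operands in row-major order; the helper name flattenIndex and the use of std::vector are illustrative only and do not appear in the patch.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Flatten multi-dimensional indices into a row-major operand position:
    // flat = i0*(d1*...*d_{r-1}) + i1*(d2*...*d_{r-1}) + ... + i_{r-1}.
    static int64_t flattenIndex(const std::vector<int64_t> &indices,
                                const std::vector<int64_t> &shape) {
      assert(indices.size() == shape.size() && "rank mismatch");
      int64_t flatIndex = 0;
      int64_t stride = 1;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        // The stride of dimension i is the product of all later dimension sizes.
        flatIndex += indices[i] * stride;
        stride *= shape[i];
      }
      return flatIndex;
    }

    int main() {
      // For a tensor<2x3xi32> built by tensor.from_elements, the element at
      // indices [1, 2] is operand number 1 * 3 + 2 = 5.
      assert(flattenIndex({1, 2}, {2, 3}) == 5);
      return 0;
    }

As in the fold above, a flat index that falls outside the operand list (possible in invalid IR that never executes) must not be folded; the fold bails out and returns an empty OpFoldResult in that case.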