diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -24,8 +24,90 @@
 using namespace mlir;
 using namespace mlir::tensor;
 
+// Implements backtracking to traverse indices of the output buffer while
+// iterating over op.elements(). Elements are consumed in row-major order (the
+// last dimension varies fastest), emitting one tensor.insert per element.
+//
+// Preconditions: `shape` has rank >= 1, `indices` has one entry per dimension,
+// and `constants[i]` is an index constant with value `i` for every `i` below
+// the largest extent in `shape`. `elementIt` is advanced past every inserted
+// element. Returns the tensor produced by the last insert, or `destination`
+// unchanged if nothing was inserted.
+static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
+                           Value destination, ArrayRef<int64_t> shape,
+                           ArrayRef<Value> constants,
+                           OperandRange::iterator &elementIt,
+                           SmallVectorImpl<Value> &indices) {
+  if (dim == static_cast<int>(shape.size()) - 1) {
+    // Innermost dimension: emit the inserts, consuming one element each.
+    for (int i = 0; i < shape.back(); ++i) {
+      indices.back() = constants[i];
+      destination = rewriter.create<tensor::InsertOp>(loc, *elementIt,
+                                                      destination, indices);
+      ++elementIt;
+    }
+    return destination;
+  }
+  // Outer dimension: fix this index and recurse into the next dimension.
+  for (int i = 0; i < shape[dim]; ++i) {
+    indices[dim] = constants[i];
+    destination = createInserts(rewriter, loc, dim + 1, destination, shape,
+                                constants, elementIt, indices);
+  }
+  return destination;
+}
+
 namespace {
 
+/// Lower tensor.from_elements to a sequence of chained tensor.insert ops on
+/// top of a tensor.empty. Elements are inserted in row-major order.
+struct FromElementsOpConverter : public OpRewritePattern<FromElementsOp> {
+  using OpRewritePattern<FromElementsOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FromElementsOp elementsOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = elementsOp.getLoc();
+    RankedTensorType tensorType = elementsOp.getType().cast<RankedTensorType>();
+    auto shape = tensorType.getShape();
+
+    // Create tensor.empty.
+    auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
+
+    // Case: tensor<elem_type>.
+    if (shape.empty()) {
+      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+          elementsOp, elementsOp.getElements().front(), emptyOp.getResult(),
+          ValueRange());
+      return success();
+    }
+
+    // Case: tensor without elements (some extent is 0). There is nothing to
+    // insert; also, if every extent is 0, `constants` below would be empty and
+    // `constants[0]` would be out of bounds.
+    if (tensorType.getNumElements() == 0) {
+      rewriter.replaceOp(elementsOp, emptyOp.getResult());
+      return success();
+    }
+
+    // Create constants for the range of possible indices [0, max{shape_i}).
+    auto maxDim = *std::max_element(shape.begin(), shape.end());
+    SmallVector<Value, 2> constants;
+    constants.reserve(maxDim);
+    for (int i = 0; i < maxDim; ++i)
+      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
+
+    // Traverse all elements and create tensor.insert ops.
+    auto elementIt = elementsOp.getElements().begin();
+    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
+    Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
+                                 shape, constants, elementIt, indices);
+
+    // Replace tensor.from_elements.
+    rewriter.replaceOp(elementsOp, result);
+    return success();
+  }
+};
+
 /// Lower tensor.generate to linalg.generic.
 struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
   using OpRewritePattern<GenerateOp>::OpRewritePattern;
@@ -172,5 +254,6 @@
 
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<GenerateOpConverter, PadOpConverter>(patterns.getContext());
+  patterns.insert<FromElementsOpConverter, GenerateOpConverter, PadOpConverter>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
--- a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
+++ b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
@@ -1,5 +1,53 @@
 // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns -canonicalize %s | FileCheck %s
 
+// CHECK-LABEL: func @tensor_from_elements_0d(
+// CHECK-SAME: %[[arg0:.*]]: index
+// CHECK: %[[empty:.*]] = tensor.empty() : tensor<index>
+// CHECK: %[[insert:.*]] = tensor.insert %[[arg0]] into %[[empty]][]
+// CHECK: return %[[insert]]
+func.func @tensor_from_elements_0d(%arg0: index) -> tensor<index> {
+  %0 = tensor.from_elements %arg0 : tensor<index>
+  return %0 : tensor<index>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_from_elements_1d(
+// CHECK-SAME: %[[arg0:.*]]: index, %[[arg1:.*]]: index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2xindex>
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[insert:.*]] = tensor.insert %[[arg0]] into %[[empty]][%[[c0]]]
+// CHECK: %[[insert2:.*]] = tensor.insert %[[arg1]] into %[[insert]][%[[c1]]]
+// CHECK: return %[[insert2]]
+func.func @tensor_from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
+  %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
+  return %0 : tensor<2xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_from_elements_2d(
+// CHECK-SAME: %[[arg0:.*]]: index, %[[arg1:.*]]: index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<3x2xindex>
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[insert0:.*]] = tensor.insert %[[arg0]] into %[[empty]][%[[c0]], %[[c0]]]
+// CHECK: %[[insert1:.*]] = tensor.insert %[[arg1]] into %[[insert0]][%[[c0]], %[[c1]]]
+// CHECK: %[[insert2:.*]] = tensor.insert %[[arg0]] into %[[insert1]][%[[c1]], %[[c0]]]
+// CHECK: %[[insert3:.*]] = tensor.insert %[[arg1]] into %[[insert2]][%[[c1]], %[[c1]]]
+// CHECK: %[[insert4:.*]] = tensor.insert %[[arg0]] into %[[insert3]][%[[c2]], %[[c0]]]
+// CHECK: %[[insert5:.*]] = tensor.insert %[[arg1]] into %[[insert4]][%[[c2]], %[[c1]]]
+// CHECK: return %[[insert5]]
+func.func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
+  %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
+    : tensor<3x2xindex>
+  return %0 : tensor<3x2xindex>
+}
+
+// -----
+
 // CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @tensor_generate(
 // CHECK-SAME: %[[s1:.*]]: index, %[[s2:.*]]: index