diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -216,23 +216,20 @@ "$_self.cast().getElementType()">]> { let summary = "element extraction operation"; let description = [{ - The `tensor.extract` op reads a tensor and returns one - element from it specified by an index list. The output of the op is a - new value with the same type as the elements of the tensor. The - arity of indices must match the rank of the accessed value (i.e., if a - tensor is of rank 3, then 3 indices are required for the extract. The - indices should all be of `index` type. + The `tensor.extract` op reads a ranked tensor and returns one element as + specified by the given indices. The result of the op is a value with the + same type as the elements of the tensor. The arity of indices must match + the rank of the accessed value. All indices should be of `index` type. 
Example: ```mlir %4 = tensor.extract %t[%1, %2] : tensor<4x4xi32> %5 = tensor.extract %rt[%1, %2] : tensor - %6 = tensor.extract %ut[%1, %2] : tensor<*xi32> ``` }]; - let arguments = (ins AnyTensor:$tensor, Variadic:$indices); + let arguments = (ins AnyRankedTensor:$tensor, Variadic:$indices); let results = (outs AnyType:$result); let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)"; @@ -242,6 +239,7 @@ build($_builder, $_state, resType, tensor, indices); }]>]; + let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -684,35 +682,33 @@ Pure, TypesMatchWith<"result type matches type of dest", "dest", "result", - "$_self.cast()">, + "$_self">, TypesMatchWith<"scalar type matches element type of dest", "dest", "scalar", "$_self.cast().getElementType()">]> { let summary = "element insertion operation"; let description = [{ - The `tensor.insert` op writes a tensor into a tensor `dest`as specified by - the operation's indices. + The `tensor.insert` op inserts a scalar into a ranked tensor `dest` as + specified by the operation's indices. - It returns a copy of `dest` with the proper slice updated with the value + It returns a copy of `dest` with the indexed position updated to the value of `scalar`. - The arity of indices must match the rank of the tensor `dest` (i.e., if a - tensor is of rank 3, then 3 indices are required for the extract. The - indices should all be of `index` type. + The arity of `indices` must match the rank of the tensor `dest`. All + indices should be of `index` type. 
Example: ```mlir %4 = tensor.insert %t into %dest[%1, %2] : tensor<4x4xi32> %5 = tensor.insert %rt into %dest[%1, %2] : tensor - %6 = tensor.insert %ut into %dest[%1, %2] : tensor<*xi32> ``` }]; let arguments = (ins AnyType:$scalar, - AnyTensor:$dest, + AnyRankedTensor:$dest, Variadic:$indices); - let results = (outs AnyTensor:$result); + let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest) }]; diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -773,6 +773,34 @@ // ExtractOp //===----------------------------------------------------------------------===// +namespace { + +/// Canonicalizes the pattern of the form +/// +/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32> +/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32> +/// +/// to +/// +/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32> +struct ExtractFromTensorCast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extract, + PatternRewriter &rewriter) const final { + auto tensorCast = extract.getTensor().getDefiningOp(); + if (!tensorCast) + return failure(); + if (!tensorCast.getSource().getType().isa()) + return failure(); + rewriter.replaceOpWithNewOp( + extract, tensorCast.getSource(), extract.getIndices()); + return success(); + } +}; + +} // namespace + void ExtractOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), "extracted"); @@ -780,10 +808,9 @@ LogicalResult ExtractOp::verify() { // Verify the # indices match if we have a ranked type. 
- if (auto tensorType = getTensor().getType().dyn_cast()) - if (tensorType.getRank() != static_cast(getIndices().size())) - return emitOpError("incorrect number of indices for extract_element"); - + auto tensorType = getTensor().getType().cast(); + if (tensorType.getRank() != static_cast(getIndices().size())) + return emitOpError("incorrect number of indices for extract_element"); return success(); } @@ -833,6 +860,11 @@ return {}; } +void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // FromElementsOp //===----------------------------------------------------------------------===// @@ -1009,9 +1041,9 @@ LogicalResult InsertOp::verify() { // Verify the # indices match if we have a ranked type. - if (auto destType = getDest().getType().dyn_cast()) - if (destType.getRank() != static_cast(getIndices().size())) - return emitOpError("incorrect number of indices"); + auto destType = getDest().getType().cast(); + if (destType.getRank() != static_cast(getIndices().size())) + return emitOpError("incorrect number of indices"); return success(); } @@ -1181,36 +1213,12 @@ } }; -/// Canonicalizes the pattern of the form -/// -/// %val = tensor.cast %source : : tensor to tensor<2xi32> -/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32> -/// -/// to -/// -/// %extracted_element = tensor.extract %source[%c0] : tensor -struct ExtractFromTensorCast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractOp extract, - PatternRewriter &rewriter) const final { - auto tensorCast = extract.getTensor().getDefiningOp(); - if (!tensorCast) - return failure(); - - rewriter.replaceOpWithNewOp( - extract, tensorCast.getSource(), extract.getIndices()); - return success(); - } -}; - } // namespace void GenerateOp::getCanonicalizationPatterns(RewritePatternSet 
&results, MLIRContext *context) { - // TODO: Move extract patterns to tensor::ExtractOp. - results.add(context); + // TODO: Move extract pattern to tensor::ExtractOp. + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -115,12 +115,12 @@ // ----- // CHECK-LABEL: func @extract_from_tensor.cast -// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32> -func.func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 { +// CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32> +func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 { // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index %c0 = arith.constant 0 : index // CHECK-NOT: tensor.cast - %casted = tensor.cast %tensor : tensor<*xf32> to tensor + %casted = tensor.cast %tensor : tensor<9xf32> to tensor // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]] %result = tensor.extract %casted[%c0] : tensor return %result : f32 diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -34,12 +34,9 @@ // CHECK-SAME: %[[SCALAR:.*]]: f32 // CHECK-SAME: %[[INDEX:.*]]: index // CHECK-SAME: %[[DEST1:.*]]: tensor -// CHECK-SAME: %[[DEST2:.*]]: tensor<*xf32> -func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor, %arg3: tensor<*xf32>) { +func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor) { // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor %0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor - // CHECK: tensor.insert %[[SCALAR]] into %[[DEST2]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<*xf32> - %1 = tensor.insert %arg0 into %arg3[%arg1, %arg1, %arg1] : tensor<*xf32> return }