diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3244,6 +3244,17 @@
     // Build a SubTensorOp with all dynamic entries.
     OpBuilderDAG<(ins "Value":$source, "ValueRange":$offsets,
       "ValueRange":$sizes, "ValueRange":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a SubTensorOp with mixed static and dynamic entries
+    // and custom result type.
+    OpBuilderDAG<(ins "RankedTensorType":$resultType, "Value":$source,
+      "ArrayRef<int64_t>":$staticOffsets, "ArrayRef<int64_t>":$staticSizes,
+      "ArrayRef<int64_t>":$staticStrides, "ValueRange":$offsets,
+      "ValueRange":$sizes, "ValueRange":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a SubTensorOp with all dynamic entries and custom result type.
+    OpBuilderDAG<(ins "RankedTensorType":$resultType, "Value":$source,
+      "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
 
@@ -3349,7 +3360,7 @@
       return source().getType().cast<RankedTensorType>();
     }
 
-    /// The result of a subtensor is always a tensor.
+    /// The result of a subtensor_insert is always a tensor.
     RankedTensorType getType() {
       return getResult().getType().cast<RankedTensorType>();
     }
@@ -3357,7 +3368,7 @@
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
     std::array<unsigned, 3> getArrayAttrRanks() {
-      unsigned rank = getSourceType().getRank();
+      unsigned rank = getType().getRank();
       return {rank, rank, rank};
     }
 
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3444,6 +3444,39 @@
         staticStridesVector, offsets, sizes, strides, attrs);
 }
 
+/// Like the mixed static/dynamic builder, but with a custom result type.
+void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
+                              RankedTensorType resultType, Value source,
+                              ArrayRef<int64_t> staticOffsets,
+                              ArrayRef<int64_t> staticSizes,
+                              ArrayRef<int64_t> staticStrides,
+                              ValueRange offsets, ValueRange sizes,
+                              ValueRange strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  build(b, result, resultType, source, offsets, sizes, strides,
+        b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
+        b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+/// Build a SubTensorOp with all dynamic entries and custom result type.
+/// All static entries are dynamic sentinels, one per dimension of `source`.
+void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
+                              RankedTensorType resultType, Value source,
+                              ValueRange offsets, ValueRange sizes,
+                              ValueRange strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
+  unsigned rank = sourceRankedTensorType.getRank();
+  SmallVector<int64_t, 4> staticOffsetsVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
+  SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
+  SmallVector<int64_t, 4> staticStridesVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
+  build(b, result, resultType, source, staticOffsetsVector, staticSizesVector,
+        staticStridesVector, offsets, sizes, strides, attrs);
+}
+
 /// Verifier for SubTensorOp.
 static LogicalResult verify(SubTensorOp op) {
   // Verify result type against inferred type.
@@ -3528,8 +3561,8 @@
                                     ValueRange offsets, ValueRange sizes,
                                     ValueRange strides,
                                     ArrayRef<NamedAttribute> attrs) {
-  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
-  unsigned rank = sourceRankedTensorType.getRank();
+  auto destRankedTensorType = dest.getType().cast<RankedTensorType>();
+  unsigned rank = destRankedTensorType.getRank();
   SmallVector<int64_t, 4> staticOffsetsVector(
       rank, ShapedType::kDynamicStrideOrOffset);
   SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -908,7 +908,7 @@
 }
 
 // CHECK-LABEL: func @subtensor_insert({{.*}}) {
-func @subtensor_insert(%t: tensor<8x16x4xf32>, %t2: tensor<16x32x8xf32>, %idx : index) {
+func @subtensor_insert(%t: tensor<8x16x4xf32>, %t2: tensor<16x32x8xf32>, %t3: tensor<4x4xf32>, %idx : index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
 
@@ -922,5 +922,10 @@
   %2 = subtensor_insert %t into %t2[%c0, %idx, %c0][%idx, 4, %idx][%c1, 1, %c1]
     : tensor<8x16x4xf32> into tensor<16x32x8xf32>
 
+  // CHECK: subtensor_insert
+  // CHECK-SAME: tensor<4x4xf32> into tensor<8x16x4xf32>
+  %3 = subtensor_insert %t3 into %t[0, 2, 0][4, 1, 4][1, 1, 1]
+    : tensor<4x4xf32> into tensor<8x16x4xf32>
+
   return
 }