diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -941,6 +941,57 @@
   }];
 }
 
+def Vector_CastSlicesOp :
+  Vector_Op<"cast_slices", [NoSideEffect]>,
+    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$source_sizes,
+                   I64ArrayAttr:$result_sizes)>,
+    Results<(outs AnyVector)> {
+  let summary = "The cast_slices op casts slices of a vector to another type";
+  let description = [{
+    The cast_slices vector operation casts slices of a 'source' vector into
+    slices of a 'result' vector, where there is a one-to-one mapping
+    between source vector and result vector slices (specified by the
+    'source_sizes' and 'result_sizes' arguments).
+
+    This operation has the following requirements:
+    *) 'source_sizes' (resp. 'result_sizes') must have exactly one strictly
+       positive entry per dimension of the source (resp. result) vector.
+    *) The number of slices generated by applying the slice sizes
+       'source_sizes' to the source vector shape must equal the number of
+       slices generated by applying the slice sizes 'result_sizes' to the
+       result vector shape.
+    *) The cast between source/result slice shapes must be trivial (i.e. it
+       must not move or reshape data): the slice shapes must be equal once
+       unit dimensions are dropped, and the source and result vectors must
+       have the same number of elements.
+
+    Example:
+
+    ```mlir
+    // Cast slices of 'vector<4x1x3x12xf32>' from shape 1x1x3x12 to 3x12
+    %1 = vector.cast_slices %0, [1, 1, 3, 12], [3, 12]
+      : vector<4x1x3x12xf32> to vector<12x12xf32>
+
+    // Cast slices of 'vector<12x12xf32>' from shape 3x12 to 1x1x3x12
+    %2 = vector.cast_slices %1, [3, 12], [1, 1, 3, 12]
+      : vector<12x12xf32> to vector<4x1x3x12xf32>
+    ```
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return getResult().getType().cast<VectorType>();
+    }
+    void getSourceSizes(SmallVectorImpl<int64_t> &results);
+    void getResultSizes(SmallVectorImpl<int64_t> &results);
+    static StringRef getSourceSizesAttrName() { return "source_sizes"; }
+    static StringRef getResultSizesAttrName() { return "result_sizes"; }
+  }];
+}
+
 def Vector_TypeCastOp :
   Vector_Op<"type_cast", [NoSideEffect]>,
     Arguments<(ins StaticShapeMemRefOf<[AnyType]>:$memref)>,
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -1587,6 +1587,134 @@
                               [&op](Twine t) { return op.emitOpError(t); });
 }
 
+//===----------------------------------------------------------------------===//
+// CastSlicesOp
+//===----------------------------------------------------------------------===//
+
+// Parses:
+//   vector.cast_slices %vector, [source_sizes], [result_sizes] attr-dict
+//     : source-vector-type to result-vector-type
+static ParseResult parseCastSlicesOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  ArrayAttr sourceSizesAttr;
+  StringRef sourceSizesAttrName = CastSlicesOp::getSourceSizesAttrName();
+  ArrayAttr resultSizesAttr;
+  StringRef resultSizesAttrName = CastSlicesOp::getResultSizesAttrName();
+
+  OpAsmParser::OperandType source;
+  Type sourceType;
+  // Parsed as a plain Type (parseKeywordType fills a 'Type &'); that the
+  // result is a vector is enforced by the op's AnyVector result constraint.
+  Type resultType;
+  return failure(
+      parser.parseOperand(source) || parser.parseComma() ||
+      parser.parseAttribute(sourceSizesAttr, sourceSizesAttrName,
+                            result.attributes) ||
+      parser.parseComma() ||
+      parser.parseAttribute(resultSizesAttr, resultSizesAttrName,
+                            result.attributes) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(sourceType) ||
+      parser.parseKeywordType("to", resultType) ||
+      parser.resolveOperand(source, sourceType, result.operands) ||
+      parser.addTypeToList(resultType, result.types));
+}
+
+// Prints the custom form accepted by parseCastSlicesOp. The size attributes
+// are printed positionally, so they are elided from the attribute dict.
+static void print(OpAsmPrinter &p, CastSlicesOp op) {
+  p << op.getOperationName() << ' ' << op.vector() << ", ";
+  p << op.source_sizes() << ", " << op.result_sizes();
+  p.printOptionalAttrDict(
+      op.getAttrs(),
+      /*elidedAttrs=*/{CastSlicesOp::getSourceSizesAttrName(),
+                       CastSlicesOp::getResultSizesAttrName()});
+  p << " : " << op.getSourceVectorType() << " to "
+    << op.getResultVectorType();
+}
+
+// Returns 'shape' with all unit dimensions dropped. Slice shapes with equal
+// canonical forms differ only by unit dimensions, so casting between them
+// does not move or reshape data.
+static SmallVector<int64_t, 4> canonicalizeShape(ArrayRef<int64_t> shape) {
+  SmallVector<int64_t, 4> result;
+  for (int64_t dim : shape)
+    if (dim != 1)
+      result.push_back(dim);
+  return result;
+}
+
+// Returns the number of slices obtained by tiling 'shape' with 'sizes',
+// counting partial trailing slices (hence ceilDiv). Requires one positive
+// entry in 'sizes' per dimension of 'shape' (see verifyCastSliceSizes).
+static int64_t countCastSlices(ArrayRef<int64_t> shape,
+                               ArrayRef<int64_t> sizes) {
+  int64_t sliceCount = 1;
+  for (size_t i = 0, e = shape.size(); i < e; ++i)
+    sliceCount *= ceilDiv(shape[i], sizes[i]);
+  return sliceCount;
+}
+
+// Checks that 'sizes', the value of attribute 'attrName', has exactly one
+// strictly positive entry per dimension of 'vectorType'. Without this,
+// countCastSlices would read past the end of 'sizes' or divide by zero.
+static LogicalResult verifyCastSliceSizes(CastSlicesOp op, StringRef attrName,
+                                          VectorType vectorType,
+                                          ArrayRef<int64_t> sizes) {
+  if (static_cast<int64_t>(sizes.size()) != vectorType.getRank())
+    return op.emitOpError("requires '")
+           << attrName << "' to have one entry per vector dimension";
+  for (int64_t size : sizes)
+    if (size <= 0)
+      return op.emitOpError("requires '")
+             << attrName << "' entries to be positive";
+  return success();
+}
+
+// Verifies the requirements listed in the CastSlicesOp description.
+static LogicalResult verify(CastSlicesOp op) {
+  VectorType sourceVectorType = op.getSourceVectorType();
+  VectorType resultVectorType = op.getResultVectorType();
+  SmallVector<int64_t, 4> sourceSizes;
+  op.getSourceSizes(sourceSizes);
+  SmallVector<int64_t, 4> resultSizes;
+  op.getResultSizes(resultSizes);
+
+  // Each size attribute must describe one slice of its vector.
+  if (failed(verifyCastSliceSizes(op, CastSlicesOp::getSourceSizesAttrName(),
+                                  sourceVectorType, sourceSizes)) ||
+      failed(verifyCastSliceSizes(op, CastSlicesOp::getResultSizesAttrName(),
+                                  resultVectorType, resultSizes)))
+    return failure();
+
+  // Slicing the source vector by 'source_sizes' must produce the same number
+  // of slices as slicing the result vector by 'result_sizes'.
+  if (countCastSlices(sourceVectorType.getShape(), sourceSizes) !=
+      countCastSlices(resultVectorType.getShape(), resultSizes))
+    return op.emitOpError(
+        "requires the same number of source/result vector slices");
+
+  // The slice shapes must be trivially castable from one to another.
+  if (canonicalizeShape(sourceSizes) != canonicalizeShape(resultSizes))
+    return op.emitOpError("requires source/result with equivalent shapes");
+
+  // A cast that does not move or reshape data cannot change the element
+  // count. The checks above do not imply this when the sizes do not evenly
+  // divide the vector shapes, since partial trailing slices may differ.
+  if (sourceVectorType.getNumElements() != resultVectorType.getNumElements())
+    return op.emitOpError(
+        "requires source/result vectors with the same number of elements");
+  return success();
+}
+
+void CastSlicesOp::getSourceSizes(SmallVectorImpl<int64_t> &results) {
+  populateFromInt64AttrArray(source_sizes(), results);
+}
+
+void CastSlicesOp::getResultSizes(SmallVectorImpl<int64_t> &results) {
+  populateFromInt64AttrArray(result_sizes(), results);
+}
+
 //===----------------------------------------------------------------------===//
 // TypeCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -889,3 +889,43 @@
   %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
     : vector<3x2x4xf32> to vector<2x3x5xf32>
 }
+
+// -----
+
+func @cast_slices_wrong_number_of_slices(%arg0 : vector<5x1x3x12xf32>) {
+  // expected-error@+1 {{requires the same number of source/result vector slices}}
+  %0 = vector.cast_slices %arg0, [1, 1, 3, 12], [3, 12]
+    : vector<5x1x3x12xf32> to vector<12x12xf32>
+}
+
+// -----
+
+func @cast_slices_cant_cast_slice_sizes(%arg0 : vector<4x1x3x12xf32>) {
+  // expected-error@+1 {{requires source/result with equivalent shapes}}
+  %0 = vector.cast_slices %arg0, [1, 2, 3, 12], [3, 12]
+    : vector<4x1x3x12xf32> to vector<12x12xf32>
+}
+
+// -----
+
+func @cast_slices_wrong_source_sizes_rank(%arg0 : vector<4x1x3x12xf32>) {
+  // expected-error@+1 {{requires 'source_sizes' to have one entry per vector dimension}}
+  %0 = vector.cast_slices %arg0, [1, 3, 12], [3, 12]
+    : vector<4x1x3x12xf32> to vector<12x12xf32>
+}
+
+// -----
+
+func @cast_slices_non_positive_size(%arg0 : vector<12x12xf32>) {
+  // expected-error@+1 {{requires 'result_sizes' entries to be positive}}
+  %0 = vector.cast_slices %arg0, [3, 12], [1, 0, 3, 12]
+    : vector<12x12xf32> to vector<4x1x3x12xf32>
+}
+
+// -----
+
+func @cast_slices_wrong_number_of_elements(%arg0 : vector<5x12xf32>) {
+  // expected-error@+1 {{requires source/result vectors with the same number of elements}}
+  %0 = vector.cast_slices %arg0, [3, 12], [3, 12]
+    : vector<5x12xf32> to vector<6x12xf32>
+}
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -233,3 +233,15 @@
 
   return %1 : vector<2x3x4xf32>
 }
+
+// CHECK-LABEL: cast_slices
+func @cast_slices(%arg0 : vector<4x1x3x12xf32>) -> (vector<4x1x3x12xf32>) {
+  // CHECK: %[[CS0:.*]] = vector.cast_slices %{{.*}}, [1, 1, 3, 12], [3, 12] : vector<4x1x3x12xf32> to vector<12x12xf32>
+  %0 = vector.cast_slices %arg0, [1, 1, 3, 12], [3, 12]
+    : vector<4x1x3x12xf32> to vector<12x12xf32>
+  // CHECK: vector.cast_slices %[[CS0]], [3, 12], [1, 1, 3, 12] : vector<12x12xf32> to vector<4x1x3x12xf32>
+  %1 = vector.cast_slices %0, [3, 12], [1, 1, 3, 12]
+    : vector<12x12xf32> to vector<4x1x3x12xf32>
+
+  return %1 : vector<4x1x3x12xf32>
+}