diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -963,6 +963,49 @@
   }];
 }
 
+def Vector_ShapeCastOp :
+  Vector_Op<"shape_cast", [NoSideEffect]>,
+    Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,
+    Results<(outs AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$result)> {
+  let summary = "shape_cast casts between vector shapes";
+  let description = [{
+    The shape_cast operation casts between an n-D source vector shape and
+    a k-D result vector shape (the element type remains the same).
+
+    If reducing rank (n > k), result dimension sizes must be a product
+    of contiguous source dimension sizes.
+    If expanding rank (n < k), source dimensions must factor into a
+    contiguous sequence of destination dimension sizes.
+    Each source dim is expanded (or contiguous sequence of source dims combined)
+    in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
+    sequence of result dims (or a single result dim), in result dimension list
+    order (i.e. 0 <= j < k). The product of all source dimension sizes and all
+    result dimension sizes must match.
+
+    If the source/result types are a tuple of vectors, the casting operation
+    described above is applied to each source/result tuple element pair.
+
+    It is currently assumed that this operation does not require moving data,
+    and that it will be canonicalized away before lowering vector operations.
+
+    Examples:
+
+    ```mlir
+    // Example casting to a lower vector rank.
+    %1 = vector.shape_cast %0 : vector<5x1x4x3xf32> to vector<20x3xf32>
+
+    // Example casting to a higher vector rank.
+    %3 = vector.shape_cast %2 : vector<10x12x8xf32> to vector<5x2x3x4x8xf32>
+
+    // Example casting a tuple of vectors of same rank, where tuple elements
+    // may have different shapes.
+    %5 = vector.shape_cast %4 : tuple<vector<3x4x2xf32>, vector<3x3x2xf32>> to
+                                tuple<vector<12x2xf32>, vector<9x2xf32>>
+    ```
+  }];
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+}
+
 def Vector_TypeCastOp :
   Vector_Op<"type_cast", [NoSideEffect]>,
     Arguments<(ins StaticShapeMemRefOf<[AnyType]>:$memref)>,
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/ADT/StringSet.h"
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::vector;
@@ -1389,6 +1390,108 @@
                            [&op](Twine t) { return op.emitOpError(t); });
 }
 
+//===----------------------------------------------------------------------===//
+// ShapeCastOp
+//===----------------------------------------------------------------------===//
+
+ParseResult parseShapeCastOp(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::OperandType source;
+  Type sourceType;
+  Type resultType;
+  return failure(parser.parseOperand(source) ||
+                 parser.parseOptionalAttrDict(result.attributes) ||
+                 parser.parseColonType(sourceType) ||
+                 parser.parseKeywordType("to", resultType) ||
+                 parser.resolveOperand(source, sourceType, result.operands) ||
+                 parser.addTypeToList(resultType, result.types));
+}
+
+static void print(OpAsmPrinter &p, ShapeCastOp op) {
+  p << op.getOperationName() << ' ' << op.source();
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.source().getType() << " to " << op.result().getType();
+}
+
+/// Returns true if each element of 'a' is equal to the product of a contiguous
+/// sequence of the elements of 'b'. Returns false otherwise.
+static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+  unsigned rankA = a.size();
+  unsigned rankB = b.size();
+  assert(rankA < rankB);
+
+  unsigned i = 0;
+  unsigned j = 0;
+  while (i < rankA && j < rankB) {
+    int64_t dimA = a[i];
+    int64_t dimB = 1;
+    while (dimB < dimA && j < rankB)
+      dimB *= b[j++];
+    if (dimA != dimB)
+      break;
+    ++i;
+  }
+
+  return i == rankA && j == rankB;
+}
+
+static LogicalResult verifyVectorShapeCast(Operation *op,
+                                           VectorType sourceVectorType,
+                                           VectorType resultVectorType) {
+  // Check that element type is the same.
+  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
+    return op->emitOpError("source/result vectors must have same element type");
+  auto sourceShape = sourceVectorType.getShape();
+  auto resultShape = resultVectorType.getShape();
+
+  // Check that product of source dim sizes matches product of result dim sizes.
+  int64_t sourceDimProduct = std::accumulate(
+      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
+  int64_t resultDimProduct = std::accumulate(
+      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
+  if (sourceDimProduct != resultDimProduct)
+    return op->emitOpError("source/result number of elements must match");
+
+  // Check the rank expansion/reduction cases.
+  unsigned sourceRank = sourceVectorType.getRank();
+  unsigned resultRank = resultVectorType.getRank();
+  if (sourceRank < resultRank) {
+    if (!isValidShapeCast(sourceShape, resultShape))
+      return op->emitOpError("invalid shape cast");
+  } else if (sourceRank > resultRank) {
+    if (!isValidShapeCast(resultShape, sourceShape))
+      return op->emitOpError("invalid shape cast");
+  }
+  return success();
+}
+
+static LogicalResult verify(ShapeCastOp op) {
+  auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>();
+  auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>();
+
+  // Check if source/result are of vector type.
+  if (sourceVectorType && resultVectorType)
+    return verifyVectorShapeCast(op, sourceVectorType, resultVectorType);
+
+  // Check if source/result are "tuple of vectors" type.
+  auto sourceTupleType = op.source().getType().dyn_cast_or_null<TupleType>();
+  auto resultTupleType = op.result().getType().dyn_cast_or_null<TupleType>();
+  if (!sourceTupleType || !resultTupleType)
+    return op.emitOpError("source/result must be of same type");
+
+  // Check that source/result tuple sizes are the same.
+  if (sourceTupleType.size() != resultTupleType.size())
+    return op.emitOpError("source/result tuples must be the same size");
+
+  // Check each source/result tuple element pair.
+  for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i)
+    if (failed(verifyVectorShapeCast(
+            op, sourceTupleType.getType(i).cast<VectorType>(),
+            resultTupleType.getType(i).cast<VectorType>())))
+      return failure();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TypeCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -889,3 +889,85 @@
   %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4] : vector<3x2x4xf32> to vector<2x3x5xf32>
 }
+
+// -----
+
+func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
+  // expected-error@+1 {{op source/result vectors must have same element type}}
+  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
+}
+
+// -----
+
+func @shape_cast_wrong_element_type_tuple(%arg0 : tuple<vector<5x4x2xf32>,
+                                                        vector<3x4x2xf32>>) {
+  // expected-error@+1 {{op source/result vectors must have same element type}}
+  %0 = vector.shape_cast %arg0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
+                                 tuple<vector<20x2xi32>, vector<12x2xi32>>
+}
+
+// -----
+
+func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
+  // expected-error@+1 {{op source/result number of elements must match}}
+  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
+}
+
+// -----
+
+func @shape_cast_wrong_num_elements_tuple(%arg0 : tuple<vector<5x4x2xf32>,
+                                                        vector<3x4x2xf32>>) {
+  // expected-error@+1 {{op source/result number of elements must match}}
+  %0 = vector.shape_cast %arg0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
+                                 tuple<vector<21x2xf32>, vector<13x2xf32>>
+}
+
+// -----
+
+func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
+  // expected-error@+1 {{invalid shape cast}}
+  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
+}
+
+// -----
+
+func @shape_cast_invalid_rank_reduction_tuple(%arg0
+  : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>) {
+  // expected-error@+1 {{invalid shape cast}}
+  %0 = vector.shape_cast %arg0: tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
+                                tuple<vector<10x4xf32>, vector<6x4xf32>>
+}
+
+// -----
+
+func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
+  // expected-error@+1 {{invalid shape cast}}
+  %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
+}
+
+// -----
+
+func @shape_cast_invalid_rank_expansion_tuple(%arg0 : tuple<vector<20x2xf32>,
+                                                            vector<12x2xf32>>) {
+  // expected-error@+1 {{invalid shape cast}}
+  %0 = vector.shape_cast %arg0 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
+                                 tuple<vector<5x2x4xf32>, vector<4x3x2xf32>>
+}
+
+// -----
+
+func @shape_cast_source_result_different_types(
+  %arg1 : tuple<vector<20x2xf32>, vector<12x2xf32>>) {
+  // expected-error@+1 {{source/result must be of same type}}
+  %1 = vector.shape_cast %arg1 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
+                                 vector<5x2x4xf32>
+}
+
+// -----
+
+func @shape_cast_different_tuple_sizes(
+  %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>) {
+  // expected-error@+1 {{op source/result tuples must be the same size}}
+  %1 = vector.shape_cast %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
+                                 tuple<vector<20x2xf32>>
+}
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -233,3 +233,18 @@
   return %1 : vector<2x3x4xf32>
 }
+
+// CHECK-LABEL: shape_cast
+func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
+                 %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>)
+  -> (vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>) {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<15x2xf32>
+  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
+
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to tuple<vector<20x2xf32>, vector<12x2xf32>>
+  %1 = vector.shape_cast %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
+                                 tuple<vector<20x2xf32>, vector<12x2xf32>>
+
+  return %0, %1 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>
+}
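
For reviewers who want to try the contiguous-product rule outside of MLIR, the following is a minimal standalone C++ sketch (not part of the patch): only the matching loop mirrors `isValidShapeCast` above; the use of plain `std::vector` and the `main` driver with shapes taken from the documentation and the tests are illustrative assumptions.

```cpp
// Standalone sketch of the contiguous-product rule behind vector.shape_cast
// verification. Only the matching loop mirrors the patch; the driver below
// is illustrative.
#include <cstdint>
#include <cstdio>
#include <vector>

// Returns true if each element of the lower-rank shape 'a' equals the product
// of a contiguous run of elements of the higher-rank shape 'b', and all of 'b'
// is consumed in the process.
static bool isValidShapeCast(const std::vector<int64_t> &a,
                             const std::vector<int64_t> &b) {
  size_t i = 0, j = 0;
  while (i < a.size() && j < b.size()) {
    int64_t dimA = a[i];
    int64_t dimB = 1;
    while (dimB < dimA && j < b.size())
      dimB *= b[j++]; // accumulate a contiguous run of b's dimensions
    if (dimA != dimB)
      break;          // the run overshot or undershot dimA: invalid cast
    ++i;
  }
  return i == a.size() && j == b.size();
}

int main() {
  // vector<5x1x3x2xf32> -> vector<15x2xf32>: 15 = 5*1*3, 2 = 2  => valid.
  std::printf("%d\n", isValidShapeCast({15, 2}, {5, 1, 3, 2}));        // 1
  // vector<5x1x3x2xf32> -> vector<2x15xf32>: no prefix of 5,1,3,2 multiplies
  // to 2, so the cast is rejected (matches the invalid.mlir test).
  std::printf("%d\n", isValidShapeCast({2, 15}, {5, 1, 3, 2}));        // 0
  // vector<15x2xf32> -> vector<5x2x3x1xf32>: 5*2 = 10, then 10*3 = 30 != 15.
  std::printf("%d\n", isValidShapeCast({15, 2}, {5, 2, 3, 1}));        // 0
  // vector<10x12x8xf32> -> vector<5x2x3x4x8xf32>: 10 = 5*2, 12 = 3*4, 8 = 8.
  std::printf("%d\n", isValidShapeCast({10, 12, 8}, {5, 2, 3, 4, 8})); // 1
  return 0;
}
```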