diff --git a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
--- a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
@@ -20,9 +20,11 @@
 using vector_fma = ValueBuilder<vector::FMAOp>;
 using vector_extract = ValueBuilder<vector::ExtractOp>;
 using vector_matmul = ValueBuilder<vector::MatmulOp>;
+using vector_outerproduct = ValueBuilder<vector::OuterProductOp>;
 using vector_print = OperationBuilder<vector::PrintOp>;
 using vector_transfer_read = ValueBuilder<vector::TransferReadOp>;
 using vector_transfer_write = OperationBuilder<vector::TransferWriteOp>;
+using vector_transpose = ValueBuilder<vector::TransposeOp>;
 using vector_type_cast = ValueBuilder<vector::TypeCastOp>;
 using vector_insert = ValueBuilder<vector::InsertOp>;
 using vector_fma = ValueBuilder<vector::FMAOp>;
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1377,6 +1377,9 @@
                           [c, f] ]
     ```
   }];
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, Value vector, "
+    "ArrayRef<int64_t> permutation">];
   let extraClassDeclaration = [{
     VectorType getVectorType() {
       return vector().getType().cast<VectorType>();
@@ -1385,6 +1388,7 @@
       return result().getType().cast<VectorType>();
     }
     void getTransp(SmallVectorImpl<int64_t> &results);
+    static StringRef getTranspAttrName() { return "transp"; }
   }];
   let assemblyFormat = [{
     $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -358,8 +358,20 @@
   /// Emits a `load` when converting to a Value.
   operator Value() const { return Load(value, indices); }
 
+  /// Return the base memref.
   Value getBase() const { return value; }
 
+  /// Return the type of the underlying memref.
+  MemRefType getMemRefType() const {
+    return value.getType().cast<MemRefType>();
+  }
+
+  /// Return the underlying MemRef elemental type cast as `T`.
+  template <typename T>
+  T getElementalTypeAs() const {
+    return getMemRefType().getElementType().cast<T>();
+  }
+
   /// Arithmetic operator overloadings.
   Value operator+(Value e);
   Value operator-(Value e);
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1693,6 +1693,30 @@
 // TransposeOp
 //===----------------------------------------------------------------------===//
 
+/// Builds a transpose of `vector` using the permutation `transp`. The result
+/// type is inferred: result dimension `i` takes the size of source dimension
+/// `transp[i]`, and the element type is carried over unchanged.
+void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
+                                Value vector, ArrayRef<int64_t> transp) {
+  VectorType vt = vector.getType().cast<VectorType>();
+  ArrayRef<int64_t> shape = vt.getShape();
+  // Guard the indexing below: a permutation of the wrong length or with an
+  // out-of-range entry would otherwise read or write out of bounds.
+  assert(static_cast<int64_t>(transp.size()) == vt.getRank() &&
+         "permutation length must match the vector rank");
+  SmallVector<int64_t, 4> transposedShape;
+  transposedShape.reserve(transp.size());
+  for (int64_t dim : transp) {
+    assert(dim >= 0 && dim < vt.getRank() &&
+           "permutation index out of bounds");
+    transposedShape.push_back(shape[dim]);
+  }
+
+  result.addOperands(vector);
+  result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
+  result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
+}
+
 // Eliminates transpose operations, which produce values identical to their
 // input values. This happens when the dimensions of the input vector remain in
 // their original order after the transpose operation.