// Reviewer annotation (outside every hunk; the hunks below are unchanged except for comment wording on added lines).
// Summary of what this patch adds, as visible in the hunks:
//  - EDSC/Intrinsics.h: the `vector_outerproduct` and `vector_transpose` aliases.
//  - VectorOps.td: a custom builder taking (Value vector, ArrayRef transp) and a static `getTranspAttrName()` that returns "transp" (presumably on the transpose op; the op header is outside this hunk).
//  - EDSC/Builders.h: `getMemRefType()` and `getElementalTypeAs()` accessors next to `getBase()`.
//  - VectorOps.cpp: `vector::TransposeOp::build`, which fills transposedShape[i] = vt.getShape()[transp[i]], adds `vector` as operand, a VectorType of that shape and the source element type as result type, and `transp` as an I64 array attribute.
// NOTE(review): `build` sizes `transposedShape` by vt.getRank() but loops over transp.size() and indexes the shape with transp[i]; nothing visible here checks that transp.size() == rank or that each entry is in range. Confirm that callers guarantee a valid permutation.
// NOTE(review): template argument lists look stripped in this copy (e.g. `cast()`, `ValueBuilder;`, `SmallVector transposedShape`). Confirm against the original patch before applying.
diff --git a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h @@ -20,9 +20,11 @@ using vector_fma = ValueBuilder; using vector_extract = ValueBuilder; using vector_matmul = ValueBuilder; +using vector_outerproduct = ValueBuilder; using vector_print = OperationBuilder; using vector_transfer_read = ValueBuilder; using vector_transfer_write = OperationBuilder; +using vector_transpose = ValueBuilder; using vector_type_cast = ValueBuilder; using vector_insert = ValueBuilder; using vector_fma = ValueBuilder; diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1385,6 +1385,9 @@ [c, f] ] ``` }]; + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Value vector, " + "ArrayRef transp">]; let extraClassDeclaration = [{ VectorType getVectorType() { return vector().getType().cast(); @@ -1393,6 +1396,7 @@ return result().getType().cast(); } void getTransp(SmallVectorImpl &results); + static StringRef getTranspAttrName() { return "transp"; } }]; let assemblyFormat = [{ $vector `,` $transp attr-dict `:` type($vector) `to` type($result) diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -358,8 +358,23 @@ /// Emits a `load` when converting to a Value. operator Value() const { return Load(value, indices); } + /// Returns the base memref value. Value getBase() const { return value; } + /// Returns the type of the underlying memref. + MemRefType getMemRefType() const { + return value.getType().template cast(); + } + + /// Returns the element type of the underlying memref, cast to `T`. 
+ template + T getElementalTypeAs() const { + return value.getType() + .template cast() + .getElementType() + .template cast(); + } + /// Arithmetic operator overloadings. Value operator+(Value e); Value operator-(Value e); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1713,6 +1713,18 @@ // TransposeOp //===----------------------------------------------------------------------===// +void vector::TransposeOp::build(OpBuilder &builder, OperationState &result, + Value vector, ArrayRef transp) { + VectorType vt = vector.getType().cast(); + SmallVector transposedShape(vt.getRank()); + for (unsigned i = 0; i < transp.size(); ++i) + transposedShape[i] = vt.getShape()[transp[i]]; + + result.addOperands(vector); + result.addTypes(VectorType::get(transposedShape, vt.getElementType())); + result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp)); +} + // Eliminates transpose operations, which produce values identical to their // input values. This happens when the dimensions of the input vector remain in // their original order after the transpose operation.