diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -286,10 +286,11 @@
     return *this;
   }
 
-  /// Create a new RankedTensorType by erasing a dim from shape.
-  RankedTensorType dropDim(unsigned dim) {
+  /// Create a new RankedTensorType by erasing a dim from shape @pos.
+  RankedTensorType dropDim(unsigned pos) {
+    assert(pos < shape.size() && "overflow");
     SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
-    newShape.erase(newShape.begin() + dim);
+    newShape.erase(newShape.begin() + pos);
     return setShape(newShape);
   }
 
@@ -303,6 +304,54 @@
   Attribute encoding;
 };
 
+//===----------------------------------------------------------------------===//
+// VectorType
+//===----------------------------------------------------------------------===//
+
+/// This is a builder type that keeps local references to arguments. Arguments
+/// that are passed into the builder must outlive the builder.
+class VectorType::Builder {
+public:
+  /// Build from another VectorType.
+  explicit Builder(VectorType other)
+      : shape(other.getShape()), elementType(other.getElementType()) {}
+
+  /// Build from scratch.
+  Builder(ArrayRef<int64_t> shape, Type elementType)
+      : shape(shape), elementType(elementType) {}
+
+  Builder &setShape(ArrayRef<int64_t> newShape) {
+    shape = newShape;
+    return *this;
+  }
+
+  Builder &setElementType(Type newElementType) {
+    elementType = newElementType;
+    return *this;
+  }
+
+  /// Create a new VectorType by erasing a dim from shape @pos.
+  /// In the particular case where the vector has a single dimension that we
+  /// drop, return the scalar element type.
+  // TODO: unify once we have a VectorType that supports 0-D.
+  Type dropDim(unsigned pos) {
+    assert(pos < shape.size() && "overflow");
+    if (shape.size() == 1)
+      return elementType;
+    SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
+    newShape.erase(newShape.begin() + pos);
+    // Build the result directly rather than via setShape(newShape): `shape` is
+    // a non-owning ArrayRef and would dangle once `newShape` is destroyed.
+    return VectorType::get(newShape, elementType);
+  }
+
+  operator VectorType() { return VectorType::get(shape, elementType); }
+
+private:
+  ArrayRef<int64_t> shape;
+  Type elementType;
+};
+
 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
 /// `originalShape` with some `1` entries erased, return the set of indices
 /// that specifies which of the entries of `originalShape` are dropped to obtain
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -929,6 +929,10 @@
     }]>
   ];
   let extraClassDeclaration = [{
+    /// This is a builder type that keeps local references to arguments.
+    /// Arguments that are passed into the builder must outlive the builder.
+    class Builder;
+
     /// Returns true of the given type can be used as an element of a vector
     /// type. In particular, vectors can consist of integer, index, or float
     /// primitives.
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -472,8 +472,8 @@
   auto rhsType = types[1].cast<VectorType>();
   auto maskElementType = parser.getBuilder().getI1Type();
   std::array<Type, 2> maskTypes = {
-      VectorType::get(lhsType.getShape(), maskElementType),
-      VectorType::get(rhsType.getShape(), maskElementType)};
+      VectorType(VectorType::Builder(lhsType).setElementType(maskElementType)),
+      VectorType(VectorType::Builder(rhsType).setElementType(maskElementType))};
   if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
     return failure();
   return success();
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -79,25 +79,6 @@
   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
 }
 
-// Helper to drop dimension from vector type.
-static Type adjustType(VectorType tp, int64_t index) {
-  int64_t rank = tp.getRank();
-  Type eltType = tp.getElementType();
-  if (rank == 1) {
-    assert(index == 0 && "index for scalar result out of bounds");
-    return eltType;
-  }
-  SmallVector<int64_t, 4> adjustedShape;
-  for (int64_t i = 0; i < rank; ++i) {
-    // Omit dimension at the given index.
-    if (i == index)
-      continue;
-    // Otherwise, add dimension back.
-    adjustedShape.push_back(tp.getDimSize(i));
-  }
-  return VectorType::get(adjustedShape, eltType);
-}
-
 // Helper method to possibly drop a dimension in a load.
 // TODO
 static Value reshapeLoad(Location loc, Value val, VectorType type,
@@ -105,7 +86,7 @@
                          PatternRewriter &rewriter) {
   if (index == -1)
     return val;
-  Type lowType = adjustType(type, 0);
+  Type lowType = VectorType::Builder(type).dropDim(/*pos=*/0);
   // At extraction dimension?
   if (index == 0) {
     auto posAttr = rewriter.getI64ArrayAttr(pos);
@@ -113,7 +94,8 @@
   }
   // Unroll leading dimensions.
   VectorType vType = lowType.cast<VectorType>();
-  VectorType resType = adjustType(type, index).cast<VectorType>();
+  VectorType resType =
+      VectorType::Builder(type).dropDim(index).cast<VectorType>();
   Value result = rewriter.create<arith::ConstantOp>(
       loc, resType, rewriter.getZeroAttr(resType));
   for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
@@ -140,9 +122,9 @@
     return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
   }
   // Unroll leading dimensions.
-  Type lowType = adjustType(type, 0);
+  Type lowType = VectorType::Builder(type).dropDim(/*pos=*/0);
   VectorType vType = lowType.cast<VectorType>();
-  Type insType = adjustType(vType, 0);
+  Type insType = VectorType::Builder(vType).dropDim(/*pos=*/0);
   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
     auto posAttr = rewriter.getI64ArrayAttr(d);
     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);