diff --git a/flang/lib/Lower/RTBuilder.h b/flang/lib/Lower/RTBuilder.h --- a/flang/lib/Lower/RTBuilder.h +++ b/flang/lib/Lower/RTBuilder.h @@ -168,7 +168,7 @@ return [](mlir::MLIRContext *context) -> mlir::Type { // FIXME: a namelist group must be some well-defined data structure, use a // tuple as a proxy for the moment - return mlir::TupleType::get(llvm::None, context); + return mlir::TupleType::get(context); }; } template <> diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -73,8 +73,8 @@ IntegerType getI64Type(); IntegerType getIntegerType(unsigned width); IntegerType getIntegerType(unsigned width, bool isSigned); - FunctionType getFunctionType(ArrayRef inputs, ArrayRef results); - TupleType getTupleType(ArrayRef elementTypes); + FunctionType getFunctionType(TypeRange inputs, TypeRange results); + TupleType getTupleType(TypeRange elementTypes); NoneType getNoneType(); /// Get or construct an instance of the type 'ty' with provided arguments. diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -17,6 +17,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/Location.h" +#include "mlir/IR/TypeRange.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/Support/InterfaceSupport.h" @@ -624,104 +625,6 @@ // Operation Value-Iterators //===----------------------------------------------------------------------===// -//===----------------------------------------------------------------------===// -// TypeRange - -/// This class provides an abstraction over the various different ranges of -/// value types. In many cases, this prevents the need to explicitly materialize -/// a SmallVector/std::vector. This class should be used in places that are not -/// suitable for a more derived type (e.g. ArrayRef) or a template range -/// parameter. -class TypeRange - : public llvm::detail::indexed_accessor_range_base< - TypeRange, - llvm::PointerUnion, Type, - Type, Type> { -public: - using RangeBaseT::RangeBaseT; - TypeRange(ArrayRef types = llvm::None); - explicit TypeRange(OperandRange values); - explicit TypeRange(ResultRange values); - explicit TypeRange(ValueRange values); - explicit TypeRange(ArrayRef values); - explicit TypeRange(ArrayRef values) - : TypeRange(ArrayRef(values.data(), values.size())) {} - template - TypeRange(ValueTypeRange values) - : TypeRange(ValueRangeT(values.begin().getCurrent(), - values.end().getCurrent())) {} - template , Arg>::value>> - TypeRange(Arg &&arg) : TypeRange(ArrayRef(std::forward(arg))) {} - TypeRange(std::initializer_list types) - : TypeRange(ArrayRef(types)) {} - -private: - /// The owner of the range is either: - /// * A pointer to the first element of an array of values. - /// * A pointer to the first element of an array of types. - /// * A pointer to the first element of an array of operands. - using OwnerT = llvm::PointerUnion; - - /// See `llvm::detail::indexed_accessor_range_base` for details. - static OwnerT offset_base(OwnerT object, ptrdiff_t index); - /// See `llvm::detail::indexed_accessor_range_base` for details. - static Type dereference_iterator(OwnerT object, ptrdiff_t index); - - /// Allow access to `offset_base` and `dereference_iterator`. - friend RangeBaseT; -}; - -//===----------------------------------------------------------------------===// -// ValueTypeRange - -/// This class implements iteration on the types of a given range of values. -template -class ValueTypeIterator final - : public llvm::mapped_iterator { - static Type unwrap(Value value) { return value.getType(); } - -public: - using reference = Type; - - /// Provide a const dereference method. - Type operator*() const { return unwrap(*this->I); } - - /// Initializes the type iterator to the specified value iterator. - ValueTypeIterator(ValueIteratorT it) - : llvm::mapped_iterator(it, &unwrap) {} -}; - -/// This class implements iteration on the types of a given range of values. -template -class ValueTypeRange final - : public llvm::iterator_range< - ValueTypeIterator> { -public: - using llvm::iterator_range< - ValueTypeIterator>::iterator_range; - template - ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {} - - /// Compare this range with another. - template - bool operator==(const OtherT &other) const { - return llvm::size(*this) == llvm::size(other) && - std::equal(this->begin(), this->end(), other.begin()); - } - template - bool operator!=(const OtherT &other) const { - return !(*this == other); - } -}; - -template -inline bool operator==(ArrayRef lhs, const ValueTypeRange &rhs) { - return lhs.size() == static_cast(llvm::size(rhs)) && - std::equal(lhs.begin(), lhs.end(), rhs.begin()); -} - //===----------------------------------------------------------------------===// // OperandRange diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -632,10 +632,10 @@ /// Get or create a new TupleType with the provided element types. Assumes the /// arguments define a well-formed type. - static TupleType get(ArrayRef elementTypes, MLIRContext *context); + static TupleType get(TypeRange elementTypes, MLIRContext *context); /// Get or create an empty tuple type. - static TupleType get(MLIRContext *context) { return get({}, context); } + static TupleType get(MLIRContext *context); /// Return the elements types for this tuple. ArrayRef getTypes() const; diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/TypeRange.h @@ -0,0 +1,181 @@ +//===- TypeRange.h ----------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the TypeRange and ValueTypeRange classes. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_TYPERANGE_H +#define MLIR_IR_TYPERANGE_H + +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/PointerUnion.h" + +namespace mlir { +class OperandRange; +class ResultRange; +class Type; +class Value; +class ValueRange; +template +class ValueTypeRange; + +//===----------------------------------------------------------------------===// +// TypeRange + +/// This class provides an abstraction over the various different ranges of +/// value types. In many cases, this prevents the need to explicitly materialize +/// a SmallVector/std::vector. This class should be used in places that are not +/// suitable for a more derived type (e.g. ArrayRef) or a template range +/// parameter. +class TypeRange + : public llvm::detail::indexed_accessor_range_base< + TypeRange, + llvm::PointerUnion, Type, + Type, Type> { +public: + using RangeBaseT::RangeBaseT; + TypeRange(ArrayRef types = llvm::None); + explicit TypeRange(OperandRange values); + explicit TypeRange(ResultRange values); + explicit TypeRange(ValueRange values); + explicit TypeRange(ArrayRef values); + explicit TypeRange(ArrayRef values) + : TypeRange(ArrayRef(values.data(), values.size())) {} + template + TypeRange(ValueTypeRange values) + : TypeRange(ValueRangeT(values.begin().getCurrent(), + values.end().getCurrent())) {} + template , Arg>::value>> + TypeRange(Arg &&arg) : TypeRange(ArrayRef(std::forward(arg))) {} + TypeRange(std::initializer_list types) + : TypeRange(ArrayRef(types)) {} + +private: + /// The owner of the range is either: + /// * A pointer to the first element of an array of values. + /// * A pointer to the first element of an array of types. + /// * A pointer to the first element of an array of operands. + using OwnerT = llvm::PointerUnion; + + /// See `llvm::detail::indexed_accessor_range_base` for details. + static OwnerT offset_base(OwnerT object, ptrdiff_t index); + /// See `llvm::detail::indexed_accessor_range_base` for details. + static Type dereference_iterator(OwnerT object, ptrdiff_t index); + + /// Allow access to `offset_base` and `dereference_iterator`. + friend RangeBaseT; +}; + +/// Make TypeRange hashable. +inline ::llvm::hash_code hash_value(TypeRange arg) { + return ::llvm::hash_combine_range(arg.begin(), arg.end()); +} + +//===----------------------------------------------------------------------===// +// ValueTypeRange + +/// This class implements iteration on the types of a given range of values. +template +class ValueTypeIterator final + : public llvm::mapped_iterator { + static Type unwrap(Value value) { return value.getType(); } + +public: + using reference = Type; + + /// Provide a const dereference method. + Type operator*() const { return unwrap(*this->I); } + + /// Initializes the type iterator to the specified value iterator. + ValueTypeIterator(ValueIteratorT it) + : llvm::mapped_iterator(it, &unwrap) {} +}; + +/// This class implements iteration on the types of a given range of values. +template +class ValueTypeRange final + : public llvm::iterator_range< + ValueTypeIterator> { +public: + using llvm::iterator_range< + ValueTypeIterator>::iterator_range; + template + ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {} + + /// Compare this range with another. + template + bool operator==(const OtherT &other) const { + return llvm::size(*this) == llvm::size(other) && + std::equal(this->begin(), this->end(), other.begin()); + } + template + bool operator!=(const OtherT &other) const { + return !(*this == other); + } +}; + +template +inline bool operator==(ArrayRef lhs, const ValueTypeRange &rhs) { + return lhs.size() == static_cast(llvm::size(rhs)) && + std::equal(lhs.begin(), lhs.end(), rhs.begin()); +} + +} // namespace mlir + +namespace llvm { + +// Provide DenseMapInfo for TypeRange. +template <> +struct DenseMapInfo { + static mlir::TypeRange getEmptyKey() { + return mlir::TypeRange(getEmptyKeyPointer(), 0); + } + + static mlir::TypeRange getTombstoneKey() { + return mlir::TypeRange(getTombstoneKeyPointer(), 0); + } + + static unsigned getHashValue(mlir::TypeRange val) { return hash_value(val); } + + static bool isEqual(mlir::TypeRange lhs, mlir::TypeRange rhs) { + if (isEmptyKey(rhs)) + return isEmptyKey(lhs); + if (isTombstoneKey(rhs)) + return isTombstoneKey(lhs); + return lhs == rhs; + } + +private: + static const mlir::Type *getEmptyKeyPointer() { + return DenseMapInfo::getEmptyKey(); + } + + static const mlir::Type *getTombstoneKeyPointer() { + return DenseMapInfo::getTombstoneKey(); + } + + static bool isEmptyKey(mlir::TypeRange range) { + if (const auto *type = range.getBase().dyn_cast()) + return type == getEmptyKeyPointer(); + return false; + } + + static bool isTombstoneKey(mlir::TypeRange range) { + if (const auto *type = range.getBase().dyn_cast()) + return type == getTombstoneKeyPointer(); + return false; + } +}; + +} // namespace llvm + +#endif // MLIR_IR_TYPERANGE_H diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -21,6 +21,7 @@ class IntegerType; class MLIRContext; class TypeStorage; +class TypeRange; namespace detail { struct FunctionTypeStorage; @@ -259,21 +260,17 @@ public: using Base::Base; - static FunctionType get(ArrayRef inputs, ArrayRef results, + static FunctionType get(TypeRange inputs, TypeRange results, MLIRContext *context); // Input types. unsigned getNumInputs() const { return getSubclassData(); } - Type getInput(unsigned i) const { return getInputs()[i]; } - ArrayRef getInputs() const; // Result types. unsigned getNumResults() const; - Type getResult(unsigned i) const { return getResults()[i]; } - ArrayRef getResults() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -549,15 +549,13 @@ else p << op.getOperand(0); - p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; + auto args = op.getOperands().drop_front(isDirect ? 0 : 1); + p << '(' << args << ')'; p.printOptionalAttrDict(op.getAttrs(), {"callee"}); // Reconstruct the function MLIR function type from operand and result types. - SmallVector argTypes( - llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); - p << " : " - << FunctionType::get(argTypes, op.getResultTypes(), op.getContext()); + << FunctionType::get(args.getTypes(), op.getResultTypes(), op.getContext()); } // ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -749,8 +749,7 @@ } FunctionType CallOp::getCalleeType() { - SmallVector argTypes(getOperandTypes()); - return FunctionType::get(argTypes, getResultTypes(), getContext()); + return FunctionType::get(getOperandTypes(), getResultTypes(), getContext()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -67,12 +67,11 @@ width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context); } -FunctionType Builder::getFunctionType(ArrayRef inputs, - ArrayRef results) { +FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) { return FunctionType::get(inputs, results, context); } -TupleType Builder::getTupleType(ArrayRef elementTypes) { +TupleType Builder::getTupleType(TypeRange elementTypes) { return TupleType::get(elementTypes, context); } diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -22,6 +22,7 @@ StandardTypes.cpp SymbolTable.cpp Types.cpp + TypeRange.cpp TypeUtilities.cpp Value.cpp Verifier.cpp diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -360,45 +360,6 @@ // Operation Value-Iterators //===----------------------------------------------------------------------===// -//===----------------------------------------------------------------------===// -// TypeRange - -TypeRange::TypeRange(ArrayRef types) - : TypeRange(types.data(), types.size()) {} -TypeRange::TypeRange(OperandRange values) - : TypeRange(values.begin().getBase(), values.size()) {} -TypeRange::TypeRange(ResultRange values) - : TypeRange(values.getBase()->getResultTypes().slice(values.getStartIndex(), - values.size())) {} -TypeRange::TypeRange(ArrayRef values) - : TypeRange(values.data(), values.size()) {} -TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) { - detail::ValueRangeOwner owner = values.begin().getBase(); - if (auto *op = reinterpret_cast(owner.ptr.dyn_cast())) - this->base = op->getResultTypes().drop_front(owner.startIndex).data(); - else if (auto *operand = owner.ptr.dyn_cast()) - this->base = operand; - else - this->base = owner.ptr.get(); -} - -/// See `llvm::detail::indexed_accessor_range_base` for details. -TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) { - if (auto *value = object.dyn_cast()) - return {value + index}; - if (auto *operand = object.dyn_cast()) - return {operand + index}; - return {object.dyn_cast() + index}; -} -/// See `llvm::detail::indexed_accessor_range_base` for details. -Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) { - if (auto *value = object.dyn_cast()) - return (value + index)->getType(); - if (auto *operand = object.dyn_cast()) - return (operand + index)->get().getType(); - return object.dyn_cast()[index]; -} - //===----------------------------------------------------------------------===// // OperandRange diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -638,10 +638,13 @@ /// Get or create a new TupleType with the provided element types. Assumes the /// arguments define a well-formed type. -TupleType TupleType::get(ArrayRef elementTypes, MLIRContext *context) { +TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) { return Base::get(context, StandardTypes::Tuple, elementTypes); } +/// Get or create an empty tuple type. +TupleType TupleType::get(MLIRContext *context) { return get({}, context); } + /// Return the elements types for this tuple. ArrayRef TupleType::getTypes() const { return getImpl()->getTypes(); } diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -15,7 +15,9 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeRange.h" #include "llvm/ADT/bit.h" #include "llvm/Support/TrailingObjects.h" @@ -105,7 +107,7 @@ inputsAndResults(inputsAndResults) {} /// The hash key used for uniquing. - using KeyTy = std::pair, ArrayRef>; + using KeyTy = std::pair; bool operator==(const KeyTy &key) const { return key == KeyTy(getInputs(), getResults()); } @@ -113,7 +115,7 @@ /// Construction. static FunctionTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { - ArrayRef inputs = key.first, results = key.second; + TypeRange inputs = key.first, results = key.second; // Copy the inputs and results into the bump pointer. SmallVector types; @@ -320,13 +322,13 @@ struct TupleTypeStorage final : public TypeStorage, public llvm::TrailingObjects { - using KeyTy = ArrayRef; + using KeyTy = TypeRange; TupleTypeStorage(unsigned numTypes) : TypeStorage(numTypes) {} /// Construction. static TupleTypeStorage *construct(TypeStorageAllocator &allocator, - ArrayRef key) { + TypeRange key) { // Allocate a new storage instance. auto byteSize = TupleTypeStorage::totalSizeToAlloc(key.size()); auto rawMem = allocator.allocate(byteSize, alignof(TupleTypeStorage)); diff --git a/mlir/lib/IR/TypeRange.cpp b/mlir/lib/IR/TypeRange.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/TypeRange.cpp @@ -0,0 +1,50 @@ +//===- TypeRange.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Operation.h" +using namespace mlir; + +//===----------------------------------------------------------------------===// +// TypeRange + +TypeRange::TypeRange(ArrayRef types) + : TypeRange(types.data(), types.size()) {} +TypeRange::TypeRange(OperandRange values) + : TypeRange(values.begin().getBase(), values.size()) {} +TypeRange::TypeRange(ResultRange values) + : TypeRange(values.getBase()->getResultTypes().slice(values.getStartIndex(), + values.size())) {} +TypeRange::TypeRange(ArrayRef values) + : TypeRange(values.data(), values.size()) {} +TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) { + detail::ValueRangeOwner owner = values.begin().getBase(); + if (auto *op = reinterpret_cast(owner.ptr.dyn_cast())) + this->base = op->getResultTypes().drop_front(owner.startIndex).data(); + else if (auto *operand = owner.ptr.dyn_cast()) + this->base = operand; + else + this->base = owner.ptr.get(); +} + +/// See `llvm::detail::indexed_accessor_range_base` for details. +TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) { + if (const auto *value = object.dyn_cast()) + return {value + index}; + if (auto *operand = object.dyn_cast()) + return {operand + index}; + return {object.dyn_cast() + index}; +} +/// See `llvm::detail::indexed_accessor_range_base` for details. +Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) { + if (const auto *value = object.dyn_cast()) + return (value + index)->getType(); + if (auto *operand = object.dyn_cast()) + return (operand + index)->get().getType(); + return object.dyn_cast()[index]; +} diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -34,7 +34,7 @@ // FunctionType //===----------------------------------------------------------------------===// -FunctionType FunctionType::get(ArrayRef inputs, ArrayRef results, +FunctionType FunctionType::get(TypeRange inputs, TypeRange results, MLIRContext *context) { return Base::get(context, Type::Kind::Function, inputs, results); }