diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -321,11 +321,14 @@
 /// parameter.
 class RegionRange
     : public llvm::detail::indexed_accessor_range_base<
-          RegionRange, PointerUnion<Region *, const std::unique_ptr<Region> *>,
+          RegionRange,
+          PointerUnion<Region *, const std::unique_ptr<Region> *, Region **>,
           Region *, Region *, Region *> {
-  /// The type representing the owner of this range. This is either a list of
-  /// values, operands, or results.
-  using OwnerT = PointerUnion<Region *, const std::unique_ptr<Region> *>;
+  /// The type representing the owner of this range. This is either an owning
+  /// list of regions, a list of region unique pointers, or a list of region
+  /// pointers.
+  using OwnerT =
+      PointerUnion<Region *, const std::unique_ptr<Region> *, Region **>;
 
 public:
   using RangeBaseT::RangeBaseT;
@@ -339,6 +342,9 @@
       : RegionRange(ArrayRef<std::unique_ptr<Region>>(std::forward<Arg>(arg))) {
   }
   RegionRange(ArrayRef<std::unique_ptr<Region>> regions);
+  /// Constructs a non-owning range over an existing list of region pointers;
+  /// the underlying pointer array must outlive the range.
+  RegionRange(ArrayRef<Region *> regions);
 
 private:
   /// See `llvm::detail::indexed_accessor_range_base` for details.
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -228,18 +228,27 @@
     : RegionRange(regions.data(), regions.size()) {}
 RegionRange::RegionRange(ArrayRef<std::unique_ptr<Region>> regions)
     : RegionRange(regions.data(), regions.size()) {}
+// OwnerT stores a non-const `Region **`, but the pointer array is only offset
+// (offset_base) and read (dereference_iterator) through it, never written, so
+// dropping the const from `ArrayRef::data()` is safe.
+RegionRange::RegionRange(ArrayRef<Region *> regions)
+    : RegionRange(const_cast<Region **>(regions.data()), regions.size()) {}
 
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 RegionRange::OwnerT RegionRange::offset_base(const OwnerT &owner,
                                              ptrdiff_t index) {
-  if (auto *operand = owner.dyn_cast<const std::unique_ptr<Region> *>())
-    return operand + index;
+  if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
+    return region + index;
+  if (auto **region = owner.dyn_cast<Region **>())
+    return region + index;
   return &owner.get<Region *>()[index];
 }
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 Region *RegionRange::dereference_iterator(const OwnerT &owner,
                                           ptrdiff_t index) {
-  if (auto *operand = owner.dyn_cast<const std::unique_ptr<Region> *>())
-    return operand[index].get();
+  if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
+    return region[index].get();
+  if (auto **region = owner.dyn_cast<Region **>())
+    return region[index];
   return &owner.get<Region *>()[index];
 }