diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -128,6 +128,7 @@
   DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values);
 
   ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
+  ArrayAttr getBoolArrayAttr(ArrayRef<bool> values);
   ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
   ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
   ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values);
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1346,6 +1346,10 @@
   Attr elementAttr = element;
 }
 
+def BoolArrayAttr : TypedArrayAttrBase<BoolAttr,
+                                       "1-bit boolean array attribute"> {
+  let constBuilderCall = "$_builder.getBoolArrayAttr($0)";
+}
 def I32ArrayAttr : TypedArrayAttrBase<I32Attr,
                                       "32-bit integer array attribute"> {
   let constBuilderCall = "$_builder.getI32ArrayAttr($0)";
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -202,6 +202,12 @@
   return SymbolRefAttr::get(value, nestedReferences, getContext());
 }
 
+ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
+  auto attrs = llvm::to_vector<8>(llvm::map_range(
+      values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
+  return getArrayAttr(attrs);
+}
+
 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
   auto attrs = llvm::to_vector<8>(llvm::map_range(
       values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));