diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -1971,18 +1971,18 @@
     ```mlir
       %a = fir.undefined !fir.array<10x10xf32>
       %c = arith.constant 3.0 : f32
-      %1 = fir.insert_on_range %a, %c, [0 : index, 7 : index, 0 : index, 2 : index] : (!fir.array<10x10xf32>, f32) -> !fir.array<10x10xf32>
+      %1 = fir.insert_on_range %a, %c from (0, 0) to (7, 2) : (!fir.array<10x10xf32>, f32) -> !fir.array<10x10xf32>
     ```
 
     The first 28 elements of %1, with coordinates from (0,0) to (7,2), have
     the value 3.0.
   }];
 
-  let arguments = (ins fir_SequenceType:$seq, AnyType:$val, ArrayAttr:$coor);
+  let arguments = (ins fir_SequenceType:$seq, AnyType:$val, IndexElementsAttr:$coor);
   let results = (outs fir_SequenceType);
 
   let assemblyFormat = [{
-    $seq `,` $val `,` $coor attr-dict `:` functional-type(operands, results)
+    $seq `,` $val custom<CustomRangeSubscript>($coor) attr-dict `:` functional-type(operands, results)
   }];
 
   let verifier = "return ::verify(*this);";
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -929,14 +929,20 @@
     return success();
   }
 
-  bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
+  /// Return true when the interleaved `[lb0, ub0, lb1, ub1, ...]` bounds in
+  /// `indexes` span every element of `seqTy`, i.e. each dimension's range is
+  /// exactly `[0, extent - 1]`.
+  bool isFullRange(mlir::DenseIntElementsAttr indexes,
+                   fir::SequenceType seqTy) const {
     auto extents = seqTy.getShape();
-    if (indexes.size() / 2 != extents.size())
+    if (indexes.size() / 2 != static_cast<int64_t>(extents.size()))
       return false;
-    for (unsigned i = 0; i < indexes.size(); i += 2) {
-      if (indexes[i].cast<IntegerAttr>().getInt() != 0)
-        return false;
-      if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1)
+    // Consume exactly one (lb, ub) pair per extent so an odd-sized attribute
+    // can never be read past its end.
+    auto curIndex = indexes.value_begin<int64_t>();
+    for (int64_t extent : extents) {
+      int64_t lb = *curIndex++;
+      int64_t ub = *curIndex++;
+      if (lb != 0 || ub != extent - 1)
         return false;
     }
     return true;
@@ -1728,14 +1734,10 @@
     SmallVector<uint64_t> lBounds;
     SmallVector<uint64_t> uBounds;
 
-    // Extract integer value from the attribute
-    SmallVector<int64_t> coordinates = llvm::to_vector<4>(
-        llvm::map_range(range.coor(), [](Attribute a) -> int64_t {
-          return a.cast<IntegerAttr>().getInt();
-        }));
-
     // Unzip the upper and lower bound and convert to a row major format.
-    for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) {
+    mlir::DenseIntElementsAttr coor = range.coor();
+    auto reversedCoor = llvm::reverse(coor.getValues<int64_t>());
+    for (auto i = reversedCoor.begin(), e = reversedCoor.end(); i != e; ++i) {
       uBounds.push_back(*i++);
       lBounds.push_back(*i);
     }
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -17,10 +17,14 @@
 #include "flang/Optimizer/Support/Utils.h"
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -1374,16 +1378,74 @@
 // InsertOnRangeOp
 //===----------------------------------------------------------------------===//
 
+/// Parse the `from (lb0, lb1, ...) to (ub0, ub1, ...)` subscript of
+/// fir.insert_on_range into an index tensor of interleaved bounds
+/// `[lb0, ub0, lb1, ub1, ...]`.
+static mlir::ParseResult
+parseCustomRangeSubscript(mlir::OpAsmParser &parser,
+                          mlir::DenseIntElementsAttr &coord) {
+  llvm::SMLoc loc = parser.getCurrentLocation();
+  llvm::SmallVector<int64_t> lbounds;
+  llvm::SmallVector<int64_t> ubounds;
+  if (parser.parseKeyword("from") ||
+      parser.parseCommaSeparatedList(
+          mlir::AsmParser::Delimiter::Paren,
+          [&] { return parser.parseInteger(lbounds.emplace_back(0)); }) ||
+      parser.parseKeyword("to") ||
+      parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::Paren, [&] {
+        return parser.parseInteger(ubounds.emplace_back(0));
+      }))
+    return mlir::failure();
+  // llvm::zip stops at the shorter list, so mismatched lists would silently
+  // drop bounds: reject them instead.
+  if (lbounds.size() != ubounds.size())
+    return parser.emitError(
+        loc, "expected the same number of lower and upper bounds");
+  llvm::SmallVector<int64_t> zippedBounds;
+  for (auto zip : llvm::zip(lbounds, ubounds)) {
+    zippedBounds.push_back(std::get<0>(zip));
+    zippedBounds.push_back(std::get<1>(zip));
+  }
+  coord = mlir::Builder(parser.getContext()).getIndexTensorAttr(zippedBounds);
+  return mlir::success();
+}
+
+/// Print the interleaved bounds of fir.insert_on_range as
+/// `from (lb0, lb1, ...) to (ub0, ub1, ...)`.
+static void printCustomRangeSubscript(mlir::OpAsmPrinter &printer,
+                                      fir::InsertOnRangeOp op,
+                                      mlir::DenseIntElementsAttr coord) {
+  printer << "from (";
+  auto indexedBounds = llvm::enumerate(coord.getValues<int64_t>());
+  // Even entries are the lower bounds.
+  llvm::interleaveComma(
+      llvm::make_filter_range(
+          indexedBounds,
+          [](auto indexed_value) { return indexed_value.index() % 2 == 0; }),
+      printer, [&](auto indexed_value) { printer << indexed_value.value(); });
+  printer << ") to (";
+  // Odd entries are the upper bounds.
+  llvm::interleaveComma(
+      llvm::make_filter_range(
+          indexedBounds,
+          [](auto indexed_value) { return indexed_value.index() % 2 != 0; }),
+      printer, [&](auto indexed_value) { printer << indexed_value.value(); });
+  printer << ")";
+}
+
 /// Range bounds must be nonnegative, and the range must not be empty.
 static mlir::LogicalResult verify(fir::InsertOnRangeOp op) {
   if (fir::hasDynamicSize(op.seq().getType()))
     return op.emitOpError("must have constant shape and size");
-  if (op.coor().size() < 2 || op.coor().size() % 2 != 0)
+  mlir::DenseIntElementsAttr coor = op.coor();
+  if (coor.size() < 2 || coor.size() % 2 != 0)
     return op.emitOpError("has uneven number of values in ranges");
   bool rangeIsKnownToBeNonempty = false;
-  for (auto i = op.coor().end(), b = op.coor().begin(); i != b;) {
-    int64_t ub = (*--i).cast<IntegerAttr>().getInt();
-    int64_t lb = (*--i).cast<IntegerAttr>().getInt();
+  // Walk the interleaved (lb, ub) pairs from the end of the list backwards.
+  auto coorValues = coor.getValues<int64_t>();
+  for (auto i = coorValues.end(), b = coorValues.begin(); i != b;) {
+    int64_t ub = *--i;
+    int64_t lb = *--i;
     if (lb < 0 || ub < 0)
       return op.emitOpError("negative range bound");
     if (rangeIsKnownToBeNonempty)
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -80,7 +80,7 @@
 fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<32x32xi32>
-  %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 31 : index, 0 : index, 31 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  %2 = fir.insert_on_range %0, %c0_i32 from (0, 0) to (31, 31) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
   fir.has_value %2 : !fir.array<32x32xi32>
 }
 
@@ -97,7 +97,7 @@
 fir.global internal @_QEmultiarray : !fir.array<32xi32> {
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<32xi32>
-  %2 = fir.insert_on_range %0, %c0_i32, [5 : index, 31 : index] : (!fir.array<32xi32>, i32) -> !fir.array<32xi32>
+  %2 = fir.insert_on_range %0, %c0_i32 from (5) to (31) : (!fir.array<32xi32>, i32) -> !fir.array<32xi32>
   fir.has_value %2 : !fir.array<32xi32>
 }
 
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -617,10 +617,10 @@
   %c1_i32 = arith.constant 9 : i32
 
   // CHECK: [[ARR2:%.*]] = fir.zero_bits !fir.array<10xi32>
-  // CHECK: [[ARR3:%.*]] = fir.insert_on_range [[ARR2]], [[C1_I32]], [2 : index, 9 : index] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
+  // CHECK: [[ARR3:%.*]] = fir.insert_on_range [[ARR2]], [[C1_I32]] from (2) to (9) : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
   // CHECK: fir.call @noret1([[ARR3]]) : (!fir.array<10xi32>) -> ()
   %arr2 = fir.zero_bits !fir.array<10xi32>
-  %arr3 = fir.insert_on_range %arr2, %c1_i32, [2 : index, 9 : index] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
+  %arr3 = fir.insert_on_range %arr2, %c1_i32 from (2) to (9) : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
   fir.call @noret1(%arr3) : (!fir.array<10xi32>) -> ()
 
   // CHECK: [[SHAPE:%.*]] = fir.shape_shift [[INDXM:%.*]], [[INDXN:%.*]], [[INDXO:%.*]], [[INDXP:%.*]] : (index, index, index, index) -> !fir.shapeshift<2>
@@ -664,6 +664,14 @@
   return
 }
 
+// CHECK-LABEL: @insert_on_range_multi_dim
+// CHECK-SAME: %[[ARR:.*]]: !fir.array<10x20xi32>, %[[CST:.*]]: i32
+func @insert_on_range_multi_dim(%arr : !fir.array<10x20xi32>, %cst : i32) {
+  // CHECK: fir.insert_on_range %[[ARR]], %[[CST]] from (2, 3) to (5, 6) : (!fir.array<10x20xi32>, i32) -> !fir.array<10x20xi32>
+  %arr3 = fir.insert_on_range %arr, %cst from (2, 3) to (5, 6) : (!fir.array<10x20xi32>, i32) -> !fir.array<10x20xi32>
+  return
+}
+
 // CHECK-LABEL: @test_shift
 func @test_shift(%arg0: !fir.box<!fir.array<?xf32>>) -> !fir.ref<f32> {
   %c4 = arith.constant 4 : index
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -428,7 +428,7 @@
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<32x32xi32>
   // expected-error@+1 {{'fir.insert_on_range' op has uneven number of values in ranges}}
-  %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 31 : index, 0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  %2 = "fir.insert_on_range"(%0, %c0_i32) { coor = dense<[0, 31, 0]> : tensor<3xindex> } : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
   fir.has_value %2 : !fir.array<32x32xi32>
 }
 
@@ -438,7 +438,7 @@
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<32x32xi32>
   // expected-error@+1 {{'fir.insert_on_range' op has uneven number of values in ranges}}
-  %2 = fir.insert_on_range %0, %c0_i32, [0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  %2 = "fir.insert_on_range"(%0, %c0_i32) { coor = dense<[0]> : tensor<1xindex> } : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
   fir.has_value %2 : !fir.array<32x32xi32>
 }
 
@@ -448,7 +448,7 @@
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<32x32xi32>
   // expected-error@+1 {{'fir.insert_on_range' op negative range bound}}
-  %2 = fir.insert_on_range %0, %c0_i32, [-1 : index, 0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  %2 = fir.insert_on_range %0, %c0_i32 from (-1) to (0) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
   fir.has_value %2 : !fir.array<32x32xi32>
 }
 
@@ -458,7 +458,7 @@
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<32x32xi32>
   // expected-error@+1 {{'fir.insert_on_range' op empty range}}
-  %2 = fir.insert_on_range %0, %c0_i32, [10 : index, 9 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  %2 = fir.insert_on_range %0, %c0_i32 from (10) to (9) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
   fir.has_value %2 : !fir.array<32x32xi32>
 }
 
@@ -468,7 +468,7 @@
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<?xi32>
   // expected-error@+1 {{'fir.insert_on_range' op must have constant shape and size}}
-  %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 10 : index] : (!fir.array<?xi32>, i32) -> !fir.array<?xi32>
+  %2 = fir.insert_on_range %0, %c0_i32 from (0) to (10) : (!fir.array<?xi32>, i32) -> !fir.array<?xi32>
   fir.has_value %2 : !fir.array<?xi32>
 }
 
@@ -478,7 +478,17 @@
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<*:i32>
   // expected-error@+1 {{'fir.insert_on_range' op must have constant shape and size}}
-  %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 10 : index] : (!fir.array<*:i32>, i32) -> !fir.array<*:i32>
+  %2 = fir.insert_on_range %0, %c0_i32 from (0) to (10) : (!fir.array<*:i32>, i32) -> !fir.array<*:i32>
   fir.has_value %2 : !fir.array<*:i32>
 }
+
+// -----
+
+fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
+  %c0_i32 = arith.constant 1 : i32
+  %0 = fir.undefined !fir.array<32x32xi32>
+  // expected-error@+1 {{expected the same number of lower and upper bounds}}
+  %2 = fir.insert_on_range %0, %c0_i32 from (0, 0) to (31) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  fir.has_value %2 : !fir.array<32x32xi32>
+}
 