diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -1987,18 +1987,18 @@ ```mlir %a = fir.undefined !fir.array<10x10xf32> %c = arith.constant 3.0 : f32 - %1 = fir.insert_on_range %a, %c, [0 : index, 7 : index, 0 : index, 2 : index] : (!fir.array<10x10xf32>, f32) -> !fir.array<10x10xf32> + %1 = fir.insert_on_range %a, %c from [0, 0] to [7, 2] : (!fir.array<10x10xf32>, f32) -> !fir.array<10x10xf32> ``` The first 28 elements of %1, with coordinates from (0,0) to (7,2), have the value 3.0. }]; - let arguments = (ins fir_SequenceType:$seq, AnyType:$val, ArrayAttr:$coor); + let arguments = (ins fir_SequenceType:$seq, AnyType:$val, IndexElementsAttr:$coor); let results = (outs fir_SequenceType); let assemblyFormat = [{ - $seq `,` $val `,` $coor attr-dict `:` functional-type(operands, results) + $seq `,` $val custom($coor) attr-dict `:` functional-type(operands, results) }]; let verifier = "return ::verify(*this);"; diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -123,14 +123,16 @@ return success(); } - bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const { + bool isFullRange(mlir::DenseIntElementsAttr indexes, + fir::SequenceType seqTy) const { auto extents = seqTy.getShape(); - if (indexes.size() / 2 != extents.size()) + if (indexes.size() / 2 != static_cast(extents.size())) return false; + auto cur_index = indexes.value_begin(); for (unsigned i = 0; i < indexes.size(); i += 2) { - if (indexes[i].cast().getInt() != 0) + if (*(cur_index++) != 0) return false; - if (indexes[i + 1].cast().getInt() != extents[i / 2] - 1) + if (*(cur_index++) != extents[i / 2] - 1) return false; } return true; diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -21,8 +21,13 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" +#include +#include +#include namespace { #include "flang/Optimizer/Dialect/CanonicalizationPatterns.inc" @@ -1385,14 +1390,60 @@ // InsertOnRangeOp //===----------------------------------------------------------------------===// +static ParseResult +parseCustomRangeSubscript(mlir::OpAsmParser &parser, + mlir::DenseIntElementsAttr &coord) { + llvm::SmallVector lbounds; + llvm::SmallVector ubounds; + if (parser.parseKeyword("from") || + parser.parseCommaSeparatedList( + AsmParser::Delimiter::Square, + [&] { return parser.parseInteger(lbounds.emplace_back(0)); }) || + parser.parseKeyword("to") || + parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&] { + return parser.parseInteger(ubounds.emplace_back(0)); + })) + return failure(); + llvm::SmallVector zippedBounds; + for (auto zip : llvm::zip(lbounds, ubounds)) { + zippedBounds.push_back(std::get<0>(zip)); + zippedBounds.push_back(std::get<1>(zip)); + } + coord = mlir::Builder(parser.getContext()).getIndexTensorAttr(zippedBounds); + return success(); +} + +void printCustomRangeSubscript(mlir::OpAsmPrinter &printer, InsertOnRangeOp op, + mlir::DenseIntElementsAttr coord) { + printer << "from ["; + auto enumerate = llvm::enumerate(coord.getValues()); + // Even entries are the lower bounds. + llvm::interleaveComma( + make_filter_range( + enumerate, + [](auto indexed_value) { return indexed_value.index() % 2 == 0; }), + printer, [&](auto indexed_value) { printer << indexed_value.value(); }); + printer << "] to ["; + // Odd entries are the upper bounds. + llvm::interleaveComma( + make_filter_range( + enumerate, + [](auto indexed_value) { return indexed_value.index() % 2 != 0; }), + printer, [&](auto indexed_value) { printer << indexed_value.value(); }); + printer << "]"; +} + /// Range bounds must be nonnegative, and the range must not be empty. static mlir::LogicalResult verify(fir::InsertOnRangeOp op) { - if (op.coor().size() < 2 || op.coor().size() % 2 != 0) + mlir::DenseIntElementsAttr coor = op.coor(); + if (coor.size() < 2 || coor.size() % 2 != 0) return op.emitOpError("has uneven number of values in ranges"); bool rangeIsKnownToBeNonempty = false; - for (auto i = op.coor().end(), b = op.coor().begin(); i != b;) { - int64_t ub = (*--i).cast().getInt(); - int64_t lb = (*--i).cast().getInt(); + for (auto i = coor.getValues().end(), + b = coor.getValues().begin(); + i != b;) { + int64_t ub = (*--i); + int64_t lb = (*--i); if (lb < 0 || ub < 0) return op.emitOpError("negative range bound"); if (rangeIsKnownToBeNonempty) diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -73,7 +73,7 @@ fir.global internal @_QEmultiarray : !fir.array<32x32xi32> { %c0_i32 = arith.constant 1 : i32 %0 = fir.undefined !fir.array<32x32xi32> - %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 31 : index, 0 : index, 31 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> + %2 = fir.insert_on_range %0, %c0_i32 from [0, 0] to [31, 31] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> fir.has_value %2 : !fir.array<32x32xi32> } diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir --- a/flang/test/Fir/fir-ops.fir +++ b/flang/test/Fir/fir-ops.fir @@ -619,10 +619,10 @@ %c1_i32 = arith.constant 9 : i32 // CHECK: [[ARR2:%.*]] = fir.zero_bits !fir.array<10xi32> - // CHECK: [[ARR3:%.*]] = fir.insert_on_range [[ARR2]], [[C1_I32]], [2 : index, 9 : index] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32> + // CHECK: [[ARR3:%.*]] = fir.insert_on_range [[ARR2]], [[C1_I32]] from [2] to [9] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32> // CHECK: fir.call @noret1([[ARR3]]) : (!fir.array<10xi32>) -> () %arr2 = fir.zero_bits !fir.array<10xi32> - %arr3 = fir.insert_on_range %arr2, %c1_i32, [2 : index, 9 : index] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32> + %arr3 = fir.insert_on_range %arr2, %c1_i32 from [2] to [9] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32> fir.call @noret1(%arr3) : (!fir.array<10xi32>) -> () // CHECK: [[SHAPE:%.*]] = fir.shape_shift [[INDXM:%.*]], [[INDXN:%.*]], [[INDXO:%.*]], [[INDXP:%.*]] : (index, index, index, index) -> !fir.shapeshift<2> @@ -649,6 +649,14 @@ return } +// CHECK-LABEL: @insert_on_range_multi_dim +// CHECK-SAME: %[[ARR:.*]]: !fir.array<10x20xi32>, %[[CST:.*]]: i32 +func @insert_on_range_multi_dim(%arr : !fir.array<10x20xi32>, %cst : i32) { + // CHECK: fir.insert_on_range %[[ARR]], %[[CST]] from [2, 3] to [5, 6] : (!fir.array<10x20xi32>, i32) -> !fir.array<10x20xi32> + %arr3 = fir.insert_on_range %arr, %cst from [2, 3] to [5, 6] : (!fir.array<10x20xi32>, i32) -> !fir.array<10x20xi32> + return +} + // CHECK-LABEL: @test_shift func @test_shift(%arg0: !fir.box>) -> !fir.ref { %c4 = arith.constant 4 : index diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir --- a/flang/test/Fir/invalid.fir +++ b/flang/test/Fir/invalid.fir @@ -428,7 +428,7 @@ %c0_i32 = arith.constant 1 : i32 %0 = fir.undefined !fir.array<32x32xi32> // expected-error@+1 {{'fir.insert_on_range' op has uneven number of values in ranges}} - %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 31 : index, 0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> + %2 = "fir.insert_on_range"(%0, %c0_i32) { coor = dense<[0, 31, 0]> : tensor<3xindex> } : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> fir.has_value %2 : !fir.array<32x32xi32> } @@ -438,7 +438,7 @@ %c0_i32 = arith.constant 1 : i32 %0 = fir.undefined !fir.array<32x32xi32> // expected-error@+1 {{'fir.insert_on_range' op has uneven number of values in ranges}} - %2 = fir.insert_on_range %0, %c0_i32, [0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> + %2 = "fir.insert_on_range"(%0, %c0_i32) { coor = dense<[0]> : tensor<1xindex> } : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> fir.has_value %2 : !fir.array<32x32xi32> } @@ -448,7 +448,7 @@ %c0_i32 = arith.constant 1 : i32 %0 = fir.undefined !fir.array<32x32xi32> // expected-error@+1 {{'fir.insert_on_range' op negative range bound}} - %2 = fir.insert_on_range %0, %c0_i32, [-1 : index, 0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> + %2 = fir.insert_on_range %0, %c0_i32 from [-1] to [0] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> fir.has_value %2 : !fir.array<32x32xi32> } @@ -458,7 +458,7 @@ %c0_i32 = arith.constant 1 : i32 %0 = fir.undefined !fir.array<32x32xi32> // expected-error@+1 {{'fir.insert_on_range' op empty range}} - %2 = fir.insert_on_range %0, %c0_i32, [10 : index, 9 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> + %2 = fir.insert_on_range %0, %c0_i32 from [10] to [9] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> fir.has_value %2 : !fir.array<32x32xi32> }