diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Pass/Pass.h"
 
 #include "llvm/ADT/ArrayRef.h"
@@ -44,6 +45,28 @@
     return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
   }
 };
+
+/// FIR conversion pattern template. Converts the result type of the matched
+/// operation with the FIR type converter and forwards the op, the converted
+/// type and the operand adaptor to `doRewrite`.
+template <typename FromOp>
+class FIROpAndTypeConversion : public FIROpConversion<FromOp> {
+public:
+  using FIROpConversion<FromOp>::FIROpConversion;
+  using OpAdaptor = typename FromOp::Adaptor;
+
+  mlir::LogicalResult
+  matchAndRewrite(FromOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const final {
+    mlir::Type ty = this->convertType(op.getType());
+    return doRewrite(op, ty, adaptor, rewriter);
+  }
+
+  /// Rewrites `op`, whose result type has already been converted to `ty`.
+  virtual mlir::LogicalResult
+  doRewrite(FromOp op, mlir::Type ty, OpAdaptor adaptor,
+            mlir::ConversionPatternRewriter &rewriter) const = 0;
+};
 } // namespace
 
 namespace {
@@ -161,6 +184,86 @@
     return success();
   }
 };
+
+/// InsertOnRange inserts a value into a sequence over a range of offsets.
+struct InsertOnRangeOpConversion
+    : public FIROpAndTypeConversion<fir::InsertOnRangeOp> {
+  using FIROpAndTypeConversion::FIROpAndTypeConversion;
+
+  // Increments an array of subscripts in a row major fashion. Only the
+  // leading `subscripts.size()` entries of `dims` are used: `dims` may also
+  // hold extents of an element type that itself lowers to an LLVM array.
+  void incrementSubscripts(llvm::ArrayRef<uint64_t> dims,
+                           llvm::SmallVectorImpl<uint64_t> &subscripts) const {
+    for (size_t i = subscripts.size(); i > 0; --i) {
+      if (++subscripts[i - 1] < dims[i - 1])
+        return;
+      subscripts[i - 1] = 0;
+    }
+  }
+
+  mlir::LogicalResult
+  doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
+            mlir::ConversionPatternRewriter &rewriter) const override {
+    llvm::SmallVector<uint64_t> dims;
+    auto type = adaptor.getOperands()[0].getType();
+
+    // Iteratively extract the array dimensions from the type.
+    while (auto t = type.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
+      dims.push_back(t.getNumElements());
+      type = t.getElementType();
+    }
+
+    // Extract the integer values from the coordinate attribute. They come as
+    // (lower, upper) pairs, one pair per dimension, in column major order.
+    SmallVector<int64_t> coordinates = llvm::to_vector<4>(
+        llvm::map_range(range.coor(), [](Attribute a) -> int64_t {
+          return a.cast<IntegerAttr>().getInt();
+        }));
+    assert(coordinates.size() % 2 == 0 &&
+           "expected a (lower, upper) bound pair per dimension");
+
+    // Unzip the upper and lower bound and convert to a row major format.
+    SmallVector<uint64_t> lBounds;
+    SmallVector<uint64_t> uBounds;
+    for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) {
+      uBounds.push_back(*i++);
+      lBounds.push_back(*i);
+    }
+    assert(uBounds.size() <= dims.size() &&
+           "more bounds than array dimensions");
+
+    auto &subscripts = lBounds;
+    auto loc = range.getLoc();
+    mlir::Value lastOp = adaptor.getOperands()[0];
+    mlir::Value insertVal = adaptor.getOperands()[1];
+
+    // Converts the current subscripts to the ArrayAttr used as the position
+    // operand of an llvm.insertvalue.
+    auto i64Ty = rewriter.getI64Type();
+    auto getPosition = [&]() {
+      SmallVector<mlir::Attribute> subscriptAttrs;
+      for (uint64_t subscript : subscripts)
+        subscriptAttrs.push_back(IntegerAttr::get(i64Ty, subscript));
+      return ArrayAttr::get(range.getContext(), subscriptAttrs);
+    };
+
+    // Insert the value at every position from the lower bound up to, but
+    // excluding, the upper bound...
+    while (subscripts != uBounds) {
+      lastOp = rewriter.create<mlir::LLVM::InsertValueOp>(
+          loc, ty, lastOp, insertVal, getPosition());
+      incrementSubscripts(dims, subscripts);
+    }
+
+    // ...then at the upper bound itself, replacing the original operation.
+    rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
+        range, ty, lastOp, insertVal, getPosition());
+
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -180,7 +283,7 @@
     auto loc = mlir::UnknownLoc::get(context);
     mlir::OwningRewritePatternList pattern(context);
     pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
-                   UndefOpConversion>(typeConverter);
+                   InsertOnRangeOpConversion, UndefOpConversion>(typeConverter);
     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             pattern);
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -81,3 +81,25 @@
 // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<1> : vector<32x32xi32>) : !llvm.array<32 x array<32 x i32>>
 // CHECK: llvm.return %[[CST]] : !llvm.array<32 x array<32 x i32>>
 // CHECK: }
+
+// -----
+
+// Test global with insert_on_range operation not covering the full array
+// in the initializer region.
+// Elements 5 through 31 (27 values) are each set by one llvm.insertvalue.
+fir.global internal @_QEmultiarray : !fir.array<32xi32> {
+  %c1_i32 = arith.constant 1 : i32
+  %0 = fir.undefined !fir.array<32xi32>
+  %2 = fir.insert_on_range %0, %c1_i32, [5 : index, 31 : index] : (!fir.array<32xi32>, i32) -> !fir.array<32xi32>
+  fir.has_value %2 : !fir.array<32xi32>
+}
+
+// CHECK: llvm.mlir.global internal @_QEmultiarray() : !llvm.array<32 x i32> {
+// CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.array<32 x i32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[5] : !llvm.array<32 x i32>
+// CHECK-COUNT-25: %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[{{.*}}] : !llvm.array<32 x i32>
+// CHECK-NEXT: %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[31] : !llvm.array<32 x i32>
+// CHECK-NOT: llvm.insertvalue
+// CHECK: llvm.return %{{.*}} : !llvm.array<32 x i32>
+// CHECK: }