diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -27,6 +27,20 @@
 // fir::LLVMTypeConverter for converting to LLVM IR dialect types.
 #include "TypeConverter.h"
 
+namespace fir {
+/// Return true if every value in `operands` is defined by a constant op.
+bool allConstants(mlir::ValueRange operands) {
+  for (auto opnd : operands) {
+    if (auto *defop = opnd.getDefiningOp())
+      if (isa<mlir::LLVM::ConstantOp>(defop) ||
+          isa<mlir::arith::ConstantOp>(defop))
+        continue;
+    return false;
+  }
+  return true;
+}
+} // namespace fir
+
 namespace {
 /// FIR conversion pattern template
 template <typename FromOp>
@@ -44,6 +58,28 @@
     return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
   }
 };
+
+/// FIR conversion pattern template for ops whose result type must be
+/// converted: the converted type is computed once and handed to `doRewrite`.
+template <typename FromOp>
+class FIROpAndTypeConversion : public FIROpConversion<FromOp> {
+public:
+  using FIROpConversion<FromOp>::FIROpConversion;
+  using OpAdaptor = typename FromOp::Adaptor;
+
+  mlir::LogicalResult
+  matchAndRewrite(FromOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const final {
+    mlir::Type ty = this->convertType(op.getType());
+    return doRewrite(op, ty, adaptor, rewriter);
+  }
+
+  /// Rewrite `op`, whose result type has already been converted to `ty`.
+  /// Pure virtual so that a missing override is a compile-time error.
+  virtual mlir::LogicalResult
+  doRewrite(FromOp op, mlir::Type ty, OpAdaptor adaptor,
+            mlir::ConversionPatternRewriter &rewriter) const = 0;
+};
 } // namespace
 
 namespace {
@@ -161,6 +197,169 @@
     return success();
   }
 };
+
+// Code shared between insert_value and extract_value Ops.
+struct ValueOpCommon {
+  /// Return the attribute of the constant op that defines `value`. The
+  /// defining op must be an `mlir::LLVM::ConstantOp` or an
+  /// `mlir::arith::ConstantOp`.
+  static mlir::Attribute getValue(mlir::Value value) {
+    // A block argument has no defining op; the *_or_null casts let it reach
+    // llvm_unreachable below instead of tripping dyn_cast's null assertion.
+    auto *defOp = value.getDefiningOp();
+    if (auto v = dyn_cast_or_null<mlir::LLVM::ConstantOp>(defOp))
+      return v.value();
+    if (auto v = dyn_cast_or_null<mlir::arith::ConstantOp>(defOp))
+      return v.value();
+    llvm_unreachable("must be a constant op");
+  }
+
+  // Translate the arguments pertaining to any multidimensional array to
+  // row-major order for LLVM-IR.
+  static void toRowMajor(SmallVectorImpl<mlir::Attribute> &attrs,
+                         mlir::Type ty) {
+    assert(ty && "type is null");
+    const auto end = attrs.size();
+    for (std::remove_const_t<decltype(end)> i = 0; i < end; ++i) {
+      if (auto seq = ty.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
+        const auto dim = getDimension(seq);
+        if (dim > 1) {
+          auto ub = std::min(i + dim, end);
+          std::reverse(attrs.begin() + i, attrs.begin() + ub);
+          i += dim - 1;
+        }
+        ty = getArrayElementType(seq);
+      } else if (auto st = ty.dyn_cast<mlir::LLVM::LLVMStructType>()) {
+        ty = st.getBody()[attrs[i].cast<mlir::IntegerAttr>().getInt()];
+      } else {
+        llvm_unreachable("index into invalid type");
+      }
+    }
+  }
+
+  // Flatten a coordinate list into integer indices: integer attributes are
+  // kept as is, and each (field name, record type) pair is replaced by the
+  // index of that field in the record.
+  static llvm::SmallVector<mlir::Attribute>
+  collectIndices(mlir::ConversionPatternRewriter &rewriter,
+                 mlir::ArrayAttr arrAttr) {
+    llvm::SmallVector<mlir::Attribute> attrs;
+    for (auto i = arrAttr.begin(), e = arrAttr.end(); i != e; ++i) {
+      if (i->isa<mlir::IntegerAttr>()) {
+        attrs.push_back(*i);
+      } else {
+        auto fieldName = i->cast<mlir::StringAttr>().getValue();
+        ++i;
+        assert(i != e && "field name must be followed by its record type");
+        auto ty = i->cast<mlir::TypeAttr>().getValue();
+        auto index = ty.cast<fir::RecordType>().getFieldIndex(fieldName);
+        attrs.push_back(mlir::IntegerAttr::get(rewriter.getI32Type(), index));
+      }
+    }
+    return attrs;
+  }
+
+private:
+  /// Number of directly nested LLVM array types starting at `ty` (>= 1).
+  static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
+    unsigned result = 1;
+    for (auto eleTy = ty.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>();
+         eleTy;
+         eleTy = eleTy.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>())
+      ++result;
+    return result;
+  }
+
+  /// Element type of `ty` once every nested array type has been peeled off.
+  static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) {
+    auto eleTy = ty.getElementType();
+    while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>())
+      eleTy = arrTy.getElementType();
+    return eleTy;
+  }
+};
+
+/// InsertOnRange inserts a value into a sequence over a range of offsets.
+struct InsertOnRangeOpConversion
+    : public FIROpAndTypeConversion<fir::InsertOnRangeOp>,
+      public ValueOpCommon {
+  using FIROpAndTypeConversion::FIROpAndTypeConversion;
+
+  // Increments an array of subscripts in a row major fashion.
+  void incrementSubscripts(const SmallVector<uint64_t> &dims,
+                           SmallVector<uint64_t> &subscripts) const {
+    for (size_t i = dims.size(); i > 0; --i) {
+      if (++subscripts[i - 1] < dims[i - 1]) {
+        return;
+      }
+      subscripts[i - 1] = 0;
+    }
+  }
+
+  mlir::LogicalResult
+  doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
+            mlir::ConversionPatternRewriter &rewriter) const override {
+    // `coor` holds one (lower, upper) subscript pair per array dimension.
+    mlir::ArrayAttr coor = range.coor();
+    assert(!coor.empty() && coor.size() % 2 == 0 &&
+           "coordinates must come in (lower, upper) pairs");
+    const unsigned rank = coor.size() / 2;
+
+    // Extract the extents of the `rank` outermost dimensions from the
+    // converted type. Stop there: the element type may itself be an LLVM
+    // array, which must not be mistaken for a dimension of the sequence.
+    llvm::SmallVector<uint64_t> dims;
+    mlir::Type type = adaptor.getOperands()[0].getType();
+    for (unsigned d = 0; d < rank; ++d) {
+      auto t = type.cast<mlir::LLVM::LLVMArrayType>();
+      dims.push_back(t.getNumElements());
+      type = t.getElementType();
+    }
+
+    // Extract the integer bounds, reversing the dimension order to convert
+    // them to row major format (matching `dims`).
+    SmallVector<uint64_t> lBounds;
+    SmallVector<uint64_t> uBounds;
+    for (unsigned d = rank; d > 0; --d) {
+      lBounds.push_back(coor[2 * d - 2].cast<IntegerAttr>().getInt());
+      uBounds.push_back(coor[2 * d - 1].cast<IntegerAttr>().getInt());
+      assert(lBounds.back() < dims[lBounds.size() - 1] &&
+             uBounds.back() < dims[uBounds.size() - 1] &&
+             "bounds must lie within the dimension extent");
+    }
+    // Positions are visited in row-major order from lBounds to uBounds; an
+    // inverted range would wrap around past the end of the sequence.
+    assert(!std::lexicographical_compare(uBounds.begin(), uBounds.end(),
+                                         lBounds.begin(), lBounds.end()) &&
+           "lower bound must not come after upper bound");
+
+    auto loc = range.getLoc();
+    mlir::Value lastOp = adaptor.getOperands()[0];
+    mlir::Value insertVal = adaptor.getOperands()[1];
+
+    // Build the position attribute of an llvm.insertvalue from subscripts.
+    auto toPosition = [&](llvm::ArrayRef<uint64_t> subs) {
+      SmallVector<mlir::Attribute> attrs;
+      for (uint64_t sub : subs)
+        attrs.push_back(IntegerAttr::get(rewriter.getI64Type(), sub));
+      return ArrayAttr::get(range.getContext(), attrs);
+    };
+
+    // Insert at every position before uBounds, then replace the op with the
+    // insertion at uBounds itself.
+    auto subscripts(lBounds);
+    while (subscripts != uBounds) {
+      lastOp = rewriter.create<mlir::LLVM::InsertValueOp>(
+          loc, ty, lastOp, insertVal, toPosition(subscripts));
+      incrementSubscripts(dims, subscripts);
+    }
+    rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
+        range, ty, lastOp, insertVal, toPosition(subscripts));
+
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -180,7 +379,7 @@
     auto loc = mlir::UnknownLoc::get(context);
     mlir::OwningRewritePatternList pattern(context);
     pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
-                   UndefOpConversion>(typeConverter);
+                   InsertOnRangeOpConversion, UndefOpConversion>(typeConverter);
     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             pattern);
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -81,3 +81,46 @@
 // CHECK:   %[[CST:.*]] = llvm.mlir.constant(dense<1> : vector<32x32xi32>) : !llvm.array<32 x array<32 x i32>>
 // CHECK:   llvm.return %[[CST]] : !llvm.array<32 x array<32 x i32>>
 // CHECK: }
+
+// -----
+
+// Test global with insert_on_range operation not covering the full array
+// in initializer region.
+
+fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
+  %c1_i32 = arith.constant 1 : i32
+  %0 = fir.undefined !fir.array<32x32xi32>
+  %2 = fir.insert_on_range %0, %c1_i32, [2 : index, 4 : index, 30 : index, 31 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  fir.has_value %2 : !fir.array<32x32xi32>
+}
+
+// CHECK: llvm.mlir.global internal @_QEmultiarray() : !llvm.array<32 x array<32 x i32>> {
+// CHECK:   %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:   %{{.*}} = llvm.mlir.undef : !llvm.array<32 x array<32 x i32>>
+// CHECK-COUNT-35: %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[{{.*}}, {{.*}}] : !llvm.array<32 x array<32 x i32>>
+// CHECK:   llvm.return %{{.*}} : !llvm.array<32 x array<32 x i32>>
+// CHECK: }
+
+// -----
+
+// Test insert_on_range on a non-square array: the FIR coordinates must be
+// reversed into the row-major position order of the LLVM aggregate.
+
+fir.global internal @_QEnonsquare : !fir.array<3x2xi32> {
+  %c1_i32 = arith.constant 1 : i32
+  %0 = fir.undefined !fir.array<3x2xi32>
+  %1 = fir.insert_on_range %0, %c1_i32, [0 : index, 2 : index, 0 : index, 1 : index] : (!fir.array<3x2xi32>, i32) -> !fir.array<3x2xi32>
+  fir.has_value %1 : !fir.array<3x2xi32>
+}
+
+// CHECK: llvm.mlir.global internal @_QEnonsquare() : !llvm.array<2 x array<3 x i32>> {
+// CHECK:   %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:   %{{.*}} = llvm.mlir.undef : !llvm.array<2 x array<3 x i32>>
+// CHECK:   %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[0 : i64, 0 : i64] : !llvm.array<2 x array<3 x i32>>
+// CHECK:   %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[0 : i64, 1 : i64] : !llvm.array<2 x array<3 x i32>>
+// CHECK:   %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[0 : i64, 2 : i64] : !llvm.array<2 x array<3 x i32>>
+// CHECK:   %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[1 : i64, 0 : i64] : !llvm.array<2 x array<3 x i32>>
+// CHECK:   %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[1 : i64, 1 : i64] : !llvm.array<2 x array<3 x i32>>
+// CHECK:   %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[1 : i64, 2 : i64] : !llvm.array<2 x array<3 x i32>>
+// CHECK:   llvm.return %{{.*}} : !llvm.array<2 x array<3 x i32>>
+// CHECK: }