diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -291,6 +291,75 @@ }]; } +def fir_AllocMemOp : fir_Op<"allocmem", + [MemoryEffects<[MemAlloc]>, AttrSizedOperandSegments]> { + let summary = "allocate storage on the heap for an object of a given type"; + + let description = [{ + Creates a heap memory reference suitable for storing a value of the + given type, T. The heap reference returned has type `!fir.heap<T>`. + The memory object is in an undefined state. `allocmem` operations must + be paired with `freemem` operations to avoid memory leaks. + + ```mlir + %0 = fir.allocmem !fir.array<10 x f32> + fir.freemem %0 : !fir.heap<!fir.array<10 x f32>> + ``` + }]; + + let arguments = (ins + TypeAttr:$in_type, + OptionalAttr:$uniq_name, + OptionalAttr:$bindc_name, + Variadic:$typeparams, + Variadic:$shape + ); + let results = (outs fir_HeapType); + + let parser = "return parseAllocMem(parser, result);"; + let printer = "printAllocMem(p, *this);"; + + let builders = [ + OpBuilder<(ins "mlir::Type":$in_type, "llvm::StringRef":$uniq_name, + "llvm::StringRef":$bindc_name, CArg<"mlir::ValueRange", "{}">:$typeparams, + CArg<"mlir::ValueRange", "{}">:$shape, + CArg<"llvm::ArrayRef", "{}">:$attributes)>, + OpBuilder<(ins "mlir::Type":$in_type, "llvm::StringRef":$uniq_name, + CArg<"mlir::ValueRange", "{}">:$typeparams, + CArg<"mlir::ValueRange", "{}">:$shape, + CArg<"llvm::ArrayRef", "{}">:$attributes)>, + OpBuilder<(ins "mlir::Type":$in_type, + CArg<"mlir::ValueRange", "{}">:$typeparams, + CArg<"mlir::ValueRange", "{}">:$shape, + CArg<"llvm::ArrayRef", "{}">:$attributes)>]; + + let verifier = [{ + llvm::SmallVector visited; + if (verifyInType(getInType(), visited, numShapeOperands())) + return emitOpError("invalid type for allocation"); + if (verifyTypeParamCount(getInType(), numLenParams())) + return emitOpError("LEN params do not 
correspond to type"); + mlir::Type outType = getType(); + if (!outType.dyn_cast()) + return emitOpError("must be a !fir.heap type"); + if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) + return emitOpError("cannot allocate !fir.box of unknown rank or type"); + return mlir::success(); + }]; + + let extraClassDeclaration = [{ + mlir::Type getAllocatedType(); + bool hasLenParams() { return !typeparams().empty(); } + bool hasShapeOperands() { return !shape().empty(); } + unsigned numLenParams() { return typeparams().size(); } + operand_range getLenParams() { return typeparams(); } + unsigned numShapeOperands() { return shape().size(); } + operand_range getShapeOperands() { return shape(); } + static mlir::Type getRefTy(mlir::Type ty); + mlir::Type getInType() { return in_type(); } + }]; +} + def fir_LoadOp : fir_OneResultOp<"load"> { let summary = "load a value from a memory reference"; let description = [{ @@ -455,37 +524,6 @@ let assemblyFormat = "type($intype) attr-dict"; } -def fir_AllocMemOp : fir_AllocatableOp<"allocmem", DefaultResource> { - let summary = "allocate storage on the heap for an object of a given type"; - - let description = [{ - Creates a heap memory reference suitable for storing a value of the - given type, T. The heap refernce returned has type `!fir.heap`. - The memory object is in an undefined state. `allocmem` operations must - be paired with `freemem` operations to avoid memory leaks. 
- - ```mlir - %0 = fir.allocmem !fir.array<10 x f32> - fir.freemem %0 : !fir.heap> - ``` - }]; - - let results = (outs fir_HeapType); - - let verifier = allocVerify#[{ - mlir::Type outType = getType(); - if (!outType.dyn_cast()) - return emitOpError("must be a !fir.heap type"); - if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) - return emitOpError("cannot allocate !fir.box of unknown rank or type"); - return mlir::success(); - }]; - - let extraClassDeclaration = extraAllocClassDeclaration#[{ - static mlir::Type wrapResultType(mlir::Type intype); - }]; -} - def fir_FreeMemOp : fir_Op<"freemem", [MemoryEffects<[MemFree]>]> { let summary = "free a heap object"; diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -134,6 +134,13 @@ /// of unknown rank or type. bool isa_unknown_size_box(mlir::Type t); +/// If `t` is a SequenceType return its element type, otherwise return `t`. +inline mlir::Type unwrapSequenceType(mlir::Type t) { + if (auto seqTy = t.dyn_cast()) + return seqTy.getEleTy(); + return t; +} + #ifndef NDEBUG // !fir.ptr and !fir.heap where X is !fir.ptr, !fir.heap, or !fir.ref // is undefined and disallowed. 
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -68,6 +68,84 @@ return false; } +static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { + auto ty = fir::unwrapSequenceType(inType); + if (numParams > 0) { + if (auto recTy = ty.dyn_cast()) + return numParams != recTy.getNumLenParams(); + if (auto chrTy = ty.dyn_cast()) + return !(numParams == 1 && chrTy.hasDynamicLen()); + return true; + } + if (auto chrTy = ty.dyn_cast()) + return !chrTy.hasConstantLen(); + return false; +} + +// Parser shared by Alloca and Allocmem +template +static mlir::ParseResult parseAllocatableOp(FN wrapResultType, + mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::Type intype; + if (parser.parseType(intype)) + return mlir::failure(); + auto &builder = parser.getBuilder(); + result.addAttribute("in_type", mlir::TypeAttr::get(intype)); + llvm::SmallVector operands; + llvm::SmallVector typeVec; + bool hasOperands = false; + std::int32_t typeparamsSize = 0; + if (!parser.parseOptionalLParen()) { + // parse the LEN params of the derived type. 
(<params> : <types>) + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || + parser.parseColonTypeList(typeVec) || parser.parseRParen()) + return mlir::failure(); + typeparamsSize = operands.size(); + hasOperands = true; + } + std::int32_t shapeSize = 0; + if (!parser.parseOptionalComma()) { + // parse size to scale by, vector of n dimensions of type index + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None)) + return mlir::failure(); + shapeSize = operands.size() - typeparamsSize; + auto idxTy = builder.getIndexType(); + for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) + typeVec.push_back(idxTy); + hasOperands = true; + } + if (hasOperands && + parser.resolveOperands(operands, typeVec, parser.getNameLoc(), + result.operands)) + return mlir::failure(); + mlir::Type restype = wrapResultType(intype); + if (!restype) { + parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype; + return mlir::failure(); + } + result.addAttribute("operand_segment_sizes", + builder.getI32VectorAttr({typeparamsSize, shapeSize})); + if (parser.parseOptionalAttrDict(result.attributes) || + parser.addTypeToList(restype, result.types)) + return mlir::failure(); + return mlir::success(); +} + +template +static void printAllocatableOp(mlir::OpAsmPrinter &p, OP &op) { + p << ' ' << op.in_type(); + if (!op.typeparams().empty()) { + p << '(' << op.typeparams() << " : " << op.typeparams().getTypes() << ')'; + } + // print the shape of the allocation (if any); all must be index type + for (auto sh : op.shape()) { + p << ", "; + p.printOperand(sh); + } + p.printOptionalAttrDict(op->getAttrs(), {"in_type", "operand_segment_sizes"}); +} + //===----------------------------------------------------------------------===// // AllocaOp //===----------------------------------------------------------------------===// @@ -92,16 +170,8 @@ // AllocMemOp //===----------------------------------------------------------------------===// 
-mlir::Type fir::AllocMemOp::getAllocatedType() { - return getType().cast().getEleTy(); -} - -mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { - return HeapType::get(ty); -} - /// Create a legal heap reference as return type -mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) { +static mlir::Type wrapAllocMemResultType(mlir::Type intype) { // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well // FIR semantics: one may not allocate a memory reference value @@ -111,6 +181,56 @@ return HeapType::get(intype); } +static mlir::ParseResult parseAllocMem(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseAllocatableOp(wrapAllocMemResultType, parser, result); +} + +static void printAllocMem(mlir::OpAsmPrinter &p, fir::AllocMemOp &op) { + printAllocatableOp(p, op); +} + +mlir::Type fir::AllocMemOp::getAllocatedType() { + return getType().cast().getEleTy(); +} + +mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { + return HeapType::get(ty); +} + +void fir::AllocMemOp::build(mlir::OpBuilder &builder, + mlir::OperationState &result, mlir::Type in_type, + llvm::StringRef uniq_name, + mlir::ValueRange typeparams, mlir::ValueRange shape, + llvm::ArrayRef attributes) { + auto nameAttr = builder.getStringAttr(uniq_name); + build(builder, result, wrapAllocMemResultType(in_type), in_type, nameAttr, {}, + typeparams, shape); + result.addAttributes(attributes); +} + +void fir::AllocMemOp::build(mlir::OpBuilder &builder, + mlir::OperationState &result, mlir::Type in_type, + llvm::StringRef uniq_name, + llvm::StringRef bindc_name, + mlir::ValueRange typeparams, mlir::ValueRange shape, + llvm::ArrayRef attributes) { + auto nameAttr = builder.getStringAttr(uniq_name); + auto bindcAttr = builder.getStringAttr(bindc_name); + build(builder, result, wrapAllocMemResultType(in_type), in_type, nameAttr, + bindcAttr, typeparams, shape); + result.addAttributes(attributes); +} + 
+void fir::AllocMemOp::build(mlir::OpBuilder &builder, + mlir::OperationState &result, mlir::Type in_type, + mlir::ValueRange typeparams, mlir::ValueRange shape, + llvm::ArrayRef attributes) { + build(builder, result, wrapAllocMemResultType(in_type), in_type, {}, {}, + typeparams, shape); + result.addAttributes(attributes); +} + //===----------------------------------------------------------------------===// // ArrayCoorOp //===----------------------------------------------------------------------===//