diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -1716,8 +1716,8 @@ "mlir::Type":$recTy, CArg<"mlir::ValueRange","{}">:$operands)>]; let extraClassDeclaration = [{ - static constexpr llvm::StringRef fieldAttrName() { return "field_id"; } - static constexpr llvm::StringRef typeAttrName() { return "on_type"; } + static constexpr llvm::StringRef getFieldAttrName() { return "field_id"; } + static constexpr llvm::StringRef getTypeAttrName() { return "on_type"; } llvm::StringRef getFieldName() { return getFieldId(); } llvm::SmallVector getAttributes(); }]; @@ -1970,21 +1970,22 @@ ``` }]; - let arguments = (ins StrAttr:$field_id, TypeAttr:$on_type); + let arguments = (ins + StrAttr:$field_id, + TypeAttr:$on_type, + Variadic:$typeparams + ); let hasCustomAssemblyFormat = 1; let builders = [OpBuilder<(ins "llvm::StringRef":$fieldName, - "mlir::Type":$recTy), - [{ - $_state.addAttribute(fieldAttrName(), $_builder.getStringAttr(fieldName)); - $_state.addAttribute(typeAttrName(), mlir::TypeAttr::get(recTy)); - }] - >]; + "mlir::Type":$recTy, CArg<"mlir::ValueRange","{}">:$operands)>]; let extraClassDeclaration = [{ - static constexpr llvm::StringRef fieldAttrName() { return "field_id"; } - static constexpr llvm::StringRef typeAttrName() { return "on_type"; } + static constexpr llvm::StringRef getFieldAttrName() { return "field_id"; } + static constexpr llvm::StringRef getTypeAttrName() { return "on_type"; } + llvm::StringRef getParamName() { return getFieldId(); } + llvm::SmallVector getAttributes(); }]; } diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp --- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp +++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp @@ -307,7 +307,7 @@ auto toTy = typeConverter.convertType(ty); auto toOnTy = 
typeConverter.convertType(onTy); rewriter.replaceOpWithNewOp( - mem, toTy, index.getFieldId(), toOnTy); + mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); } } else if (op->getDialect() == firDialect) { rewriter.startRootUpdate(op); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -924,7 +924,7 @@ } mlir::LogicalResult fir::CoordinateOp::verify() { - auto refTy = getRef().getType(); + const mlir::Type refTy = getRef().getType(); if (fir::isa_ref_type(refTy)) { auto eleTy = fir::dyn_cast_ptrEleTy(refTy); if (auto arrTy = eleTy.dyn_cast()) { @@ -935,18 +935,70 @@ } if (!(fir::isa_aggregate(eleTy) || fir::isa_complex(eleTy) || fir::isa_char_string(eleTy))) - return emitOpError("cannot apply coordinate_of to this type"); - } - // Recovering a LEN type parameter only makes sense from a boxed value. For a - // bare reference, the LEN type parameters must be passed as additional - // arguments to `op`. - for (auto co : getCoor()) - if (mlir::dyn_cast_or_null(co.getDefiningOp())) { - if (getNumOperands() != 2) - return emitOpError("len_param_index must be last argument"); - if (!getRef().getType().isa()) - return emitOpError("len_param_index must be used on box type"); + return emitOpError("cannot apply to this element type"); + } + auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(refTy); + unsigned dimension = 0; + const unsigned numCoors = getCoor().size(); + for (auto coorOperand : llvm::enumerate(getCoor())) { + auto co = coorOperand.value(); + if (dimension == 0 && eleTy.isa()) { + dimension = eleTy.cast().getDimension(); + if (dimension == 0) + return emitOpError("cannot apply to array of unknown rank"); } + if (auto *defOp = co.getDefiningOp()) { + if (auto index = mlir::dyn_cast(defOp)) { + // Recovering a LEN type parameter only makes sense from a boxed + // value. 
For a bare reference, the LEN type parameters must be + // passed as additional arguments to `index`. + if (refTy.isa()) { + if (coorOperand.index() != numCoors - 1) + return emitOpError("len_param_index must be last argument"); + if (getNumOperands() != 2) + return emitOpError("too many operands for len_param_index case"); + } + if (eleTy != index.getOnType()) + return emitOpError( + "len_param_index type not compatible with reference type"); + return mlir::success(); + } else if (auto index = mlir::dyn_cast(defOp)) { + if (eleTy != index.getOnType()) + return emitOpError("field_index type not compatible with reference type"); + if (auto recTy = eleTy.dyn_cast()) { + eleTy = recTy.getType(index.getFieldName()); + continue; + } + return emitOpError("field_index not applied to !fir.type"); + } + } + if (dimension) { + if (--dimension == 0) + eleTy = eleTy.cast().getEleTy(); + } else { + if (auto t = eleTy.dyn_cast()) { + // FIXME: Generally, we don't know which field of the tuple is being + // referred to unless the operand is a constant. Just assume everything + // is good in the tuple case for now. + return mlir::success(); + } else if (auto t = eleTy.dyn_cast()) { + // FIXME: This is the same as the tuple case. 
+ return mlir::success(); + } else if (auto t = eleTy.dyn_cast()) { + eleTy = t.getElementType(); + } else if (auto t = eleTy.dyn_cast()) { + eleTy = t.getElementType(); + } else if (auto t = eleTy.dyn_cast()) { + if (t.getLen() == fir::CharacterType::singleton()) + return emitOpError("cannot apply to character singleton"); + eleTy = fir::CharacterType::getSingleton(t.getContext(), t.getFKind()); + if (fir::unwrapRefType(getType()) != eleTy) + return emitOpError("character type mismatch"); + } else { + return emitOpError("invalid parameters (too many)"); + } + } + } return mlir::success(); } @@ -1331,19 +1383,20 @@ // FieldIndexOp //===----------------------------------------------------------------------===// -mlir::ParseResult fir::FieldIndexOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) { +template +mlir::ParseResult parseFieldLikeOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { llvm::StringRef fieldName; auto &builder = parser.getBuilder(); mlir::Type recty; if (parser.parseOptionalKeyword(&fieldName) || parser.parseComma() || parser.parseType(recty)) return mlir::failure(); - result.addAttribute(fir::FieldIndexOp::fieldAttrName(), + result.addAttribute(fir::FieldIndexOp::getFieldAttrName(), builder.getStringAttr(fieldName)); if (!recty.dyn_cast()) return mlir::failure(); - result.addAttribute(fir::FieldIndexOp::typeAttrName(), + result.addAttribute(fir::FieldIndexOp::getTypeAttrName(), mlir::TypeAttr::get(recty)); if (!parser.parseOptionalLParen()) { llvm::SmallVector operands; @@ -1354,23 +1407,30 @@ parser.resolveOperands(operands, types, loc, result.operands)) return mlir::failure(); } - mlir::Type fieldType = fir::FieldType::get(builder.getContext()); + mlir::Type fieldType = TY::get(builder.getContext()); if (parser.addTypeToList(fieldType, result.types)) return mlir::failure(); return mlir::success(); } -void fir::FieldIndexOp::print(mlir::OpAsmPrinter &p) { +mlir::ParseResult 
fir::FieldIndexOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseFieldLikeOp(parser, result); +} + +template +void printFieldLikeOp(mlir::OpAsmPrinter &p, OP &op) { p << ' ' - << getOperation() - ->getAttrOfType(fir::FieldIndexOp::fieldAttrName()) + << op.getOperation() + ->template getAttrOfType( + fir::FieldIndexOp::getFieldAttrName()) .getValue() - << ", " << getOperation()->getAttr(fir::FieldIndexOp::typeAttrName()); - if (getNumOperands()) { + << ", " << op.getOperation()->getAttr(fir::FieldIndexOp::getTypeAttrName()); + if (op.getNumOperands()) { p << '('; - p.printOperands(getTypeparams()); - const auto *sep = ") : "; - for (auto op : getTypeparams()) { + p.printOperands(op.getTypeparams()); + auto sep = ") : "; + for (auto op : op.getTypeparams()) { p << sep; if (op) p.printType(op.getType()); @@ -1381,12 +1441,16 @@ } } +void fir::FieldIndexOp::print(mlir::OpAsmPrinter &p) { + printFieldLikeOp(p, *this); +} + void fir::FieldIndexOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, llvm::StringRef fieldName, mlir::Type recTy, mlir::ValueRange operands) { - result.addAttribute(fieldAttrName(), builder.getStringAttr(fieldName)); - result.addAttribute(typeAttrName(), mlir::TypeAttr::get(recTy)); + result.addAttribute(getFieldAttrName(), builder.getStringAttr(fieldName)); + result.addAttribute(getTypeAttrName(), mlir::TypeAttr::get(recTy)); result.addOperands(operands); } @@ -1767,31 +1831,27 @@ mlir::ParseResult fir::LenParamIndexOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { - llvm::StringRef fieldName; - auto &builder = parser.getBuilder(); - mlir::Type recty; - if (parser.parseOptionalKeyword(&fieldName) || parser.parseComma() || - parser.parseType(recty)) - return mlir::failure(); - result.addAttribute(fir::LenParamIndexOp::fieldAttrName(), - builder.getStringAttr(fieldName)); - if (!recty.dyn_cast()) - return mlir::failure(); - result.addAttribute(fir::LenParamIndexOp::typeAttrName(), - 
mlir::TypeAttr::get(recty)); - mlir::Type lenType = fir::LenType::get(builder.getContext()); - if (parser.addTypeToList(lenType, result.types)) - return mlir::failure(); - return mlir::success(); + return parseFieldLikeOp(parser, result); } void fir::LenParamIndexOp::print(mlir::OpAsmPrinter &p) { - p << ' ' - << getOperation() - ->getAttrOfType( - fir::LenParamIndexOp::fieldAttrName()) - .getValue() - << ", " << getOperation()->getAttr(fir::LenParamIndexOp::typeAttrName()); + printFieldLikeOp(p, *this); +} + +void fir::LenParamIndexOp::build(mlir::OpBuilder &builder, + mlir::OperationState &result, + llvm::StringRef fieldName, mlir::Type recTy, + mlir::ValueRange operands) { + result.addAttribute(getFieldAttrName(), builder.getStringAttr(fieldName)); + result.addAttribute(getTypeAttrName(), mlir::TypeAttr::get(recTy)); + result.addOperands(operands); +} + +llvm::SmallVector fir::LenParamIndexOp::getAttributes() { + llvm::SmallVector attrs; + attrs.push_back(getFieldIdAttr()); + attrs.push_back(getOnTypeAttr()); + return attrs; } //===----------------------------------------------------------------------===// diff --git a/flang/test/Fir/Todo/coordinate_of_1.fir b/flang/test/Fir/Todo/coordinate_of_1.fir deleted file mode 100644 --- a/flang/test/Fir/Todo/coordinate_of_1.fir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: %not_todo_cmd fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s 2>&1 | FileCheck %s - -// `fir.coordinate_of` - derived type with `fir.len_param_index`. 
As -// `fir.len_param_index` is not implemented yet, that's the error that's -// currently being generated (this error is generated before trying to convert -// `fir.coordinate_of`) -func.func @coordinate_box_derived_with_fir_len(%arg0: !fir.box>) { -// CHECK: not yet implemented: fir.len_param_index codegen - %e = fir.len_param_index len1, !fir.type - %q = fir.coordinate_of %arg0, %e : (!fir.box>, !fir.len) -> !fir.ref - return -} diff --git a/flang/test/Fir/Todo/coordinate_of_5.fir b/flang/test/Fir/Todo/coordinate_of_5.fir deleted file mode 100644 --- a/flang/test/Fir/Todo/coordinate_of_5.fir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: %not_todo_cmd fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s 2>&1 | FileCheck %s - -// CHECK: unsupported combination of coordinate operands -func.func @test_coordinate_of(%arr : !fir.ref>>, %arg1: index) { - %1 = arith.constant 10 : i32 - %2 = fir.coordinate_of %arr, %arg1, %1 : (!fir.ref>>, index, i32) -> !fir.ref> - return -} diff --git a/flang/test/Fir/Todo/coordinate_of_6.fir b/flang/test/Fir/Todo/coordinate_of_6.fir deleted file mode 100644 --- a/flang/test/Fir/Todo/coordinate_of_6.fir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: %not_todo_cmd fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s 2>&1 | FileCheck %s - -// CHECK: unsupported combination of coordinate operands - -func.func @test_coordinate_of(%arr : !fir.ref>, %arg1: index) { - %2 = fir.coordinate_of %arr, %arg1, %arg1 : (!fir.ref>, index, index) -> !fir.ref - return -} diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -2642,13 +2642,13 @@ // 5.3 `fir.char` func.func @test_coordinate_of_char(%arr : !fir.ref>) { %1 = arith.constant 10 : i32 - %2 = fir.coordinate_of %arr, %1 : (!fir.ref>, i32) -> !fir.ref + %2 = fir.coordinate_of %arr, %1 : (!fir.ref>, i32) -> !fir.ref> return } // CHECK-LABEL: llvm.func @test_coordinate_of_char( // 
CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr>) { // CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(10 : i32) : i32 -// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr>, i32) -> !llvm.ptr +// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr>, i32) -> !llvm.ptr> // CHECK: llvm.return // CHECK: } diff --git a/flang/test/Fir/coordinate_of_1.fir b/flang/test/Fir/coordinate_of_1.fir new file mode 100644 --- /dev/null +++ b/flang/test/Fir/coordinate_of_1.fir @@ -0,0 +1,49 @@ +// RUN: fir-opt --split-input-file --verify-diagnostics %s + +func.func @_QPcoordinate_box_derived_with_fir_len(%arg0: !fir.box>) -> i32 { + %lp = arith.constant 22 : i32 + %e = fir.len_param_index len1, !fir.type(%lp : i32) + // expected-error@+1 {{'fir.coordinate_of' op len_param_index type not compatible with reference type}} + %q = fir.coordinate_of %arg0, %e : (!fir.box>, !fir.len) -> !fir.ref + %val = fir.load %q : !fir.ref + return %val : i32 +} + +// ----- + +func.func @_QPcoordinate_box_derived_with_fir_len2(%arg0: !fir.box>>) -> i32 { + %lp = arith.constant 22 : i32 + %e = fir.len_param_index len1, !fir.type(%lp : i32) + // expected-error@+1 {{'fir.coordinate_of' op too many operands for len_param_index case}} + %q = fir.coordinate_of %arg0, %lp, %e : (!fir.box>>, i32, !fir.len) -> !fir.ref + %val = fir.load %q : !fir.ref + return %val : i32 +} + +// ----- + +func.func @_QPcoordinate_box_derived_with_fir_len3(%arg0: !fir.box>) -> i32 { + %lp = arith.constant 22 : i32 + %e = fir.len_param_index len1, !fir.type(%lp : i32) + // expected-error@+1 {{'fir.coordinate_of' op len_param_index must be last argument}} + %q = fir.coordinate_of %arg0, %e, %e : (!fir.box>, !fir.len, !fir.len) -> !fir.ref + %val = fir.load %q : !fir.ref + return %val : i32 +} + +// ----- + +func.func @_QPtest_coordinate_of(%arr : !fir.ref>>, %arg1: index) { + %1 = arith.constant 10 : i32 + // expected-error@+1 {{'fir.coordinate_of' op character type mismatch}} + %2 = 
fir.coordinate_of %arr, %arg1, %1 : (!fir.ref>>, index, i32) -> !fir.ref> + return +} + +// ----- + +func.func @_QPtest_coordinate_of(%arr : !fir.ref>, %arg1: index) { + // expected-error@+1 {{'fir.coordinate_of' op invalid parameters (too many)}} + %2 = fir.coordinate_of %arr, %arg1, %arg1 : (!fir.ref>, index, index) -> !fir.ref + return +} diff --git a/flang/test/Fir/coordinateof.fir b/flang/test/Fir/coordinateof.fir new file mode 100644 --- /dev/null +++ b/flang/test/Fir/coordinateof.fir @@ -0,0 +1,80 @@ +// RUN: fir-opt %s | tco | FileCheck %s + +// tests on coordinate_of op + +// CHECK-LABEL: @foo1 +func.func @foo1(%i : i32, %j : i32, %k : i32) -> !fir.ref { + %1 = fir.alloca !fir.array<10 x 20 x 30 x f32> + // CHECK: %[[alloca:.*]] = alloca [30 x [20 x [10 x float]]] + %2 = fir.convert %1 : (!fir.ref>) -> !fir.ref> + // CHECK: getelementptr [20 x [10 x float]], ptr %[[alloca]] + %3 = fir.coordinate_of %2, %i, %j, %k : (!fir.ref>, i32, i32, i32) -> !fir.ref + return %3 : !fir.ref +} + +// CHECK-LABEL: @foo2 +func.func @foo2(%i : i32, %j : i32, %k : i32) -> !fir.ref { + %1 = fir.alloca !fir.array<10 x 20 x 30 x f32> + // CHECK: %[[alloca:.*]] = alloca [30 x [20 x [10 x float]]] + %2 = fir.convert %1 : (!fir.ref>) -> !fir.ref> + // CHECK: getelementptr float, ptr %[[alloca]] + %3 = fir.coordinate_of %2, %i : (!fir.ref>, i32) -> !fir.ref + return %3 : !fir.ref +} + +// CHECK-LABEL: @foo3 +func.func @foo3(%box : !fir.box>, %i : i32) -> i32 { + // CHECK: %[[cvt:.*]] = sext i32 % + %ii = fir.convert %i : (i32) -> index + // CHECK: %[[gep0:.*]] = getelementptr { ptr + // CHECK: %[[boxptr:.*]] = load ptr, ptr %[[gep0]] + // CHECK: %[[gep1:.*]] = getelementptr { ptr, i64, {{.*}} i32 7 + // CHECK: %[[stride:.*]] = load i64, ptr %[[gep1]] + // CHECK: %[[dimoffset:.*]] = mul i64 %[[cvt]], %[[stride]] + // CHECK: %[[offset:.*]] = add i64 %[[dimoffset]], 0 + // CHECK: %[[gep2:.*]] = getelementptr i8, ptr %[[boxptr]], i64 %[[offset]] + %1 = fir.coordinate_of %box, %ii : 
(!fir.box>, index) -> !fir.ref + // CHECK: load i32, ptr %[[gep2]] + %rv = fir.load %1 : !fir.ref + return %rv : i32 +} + +// CHECK-LABEL: @foo4 +func.func @foo4(%a : !fir.ptr>, %i : i32, %j : i64, %k : index) -> i32 { + // CHECK: getelementptr [25 x [15 x [5 x + %1 = fir.coordinate_of %a, %k : (!fir.ptr>, index) -> !fir.ref> + // CHECK: getelementptr [15 x [5 x + %2 = fir.coordinate_of %1, %j : (!fir.ref>, i64) -> !fir.ref> + // CHECK: %[[ref:.*]] = getelementptr [5 x + %3 = fir.coordinate_of %2, %i : (!fir.ref>, i32) -> !fir.ref + // CHECK: load i32, ptr %[[ref]] + %4 = fir.load %3 : !fir.ref + return %4 : i32 +} + +// CHECK-LABEL: @foo5 +func.func @foo5(%box : !fir.box>>, %i : index) -> i32 { + // similar to foo3 test. Just check that the ptr type is not disturbing codegen. + %1 = fir.coordinate_of %box, %i : (!fir.box>>, index) -> !fir.ref + // CHECK: load i32, ptr %{{.*}} + %rv = fir.load %1 : !fir.ref + return %rv : i32 +} + +// CHECK-LABEL: @foo6 +// CHECK-SAME: (ptr %[[box:.*]], i64 %{{.*}}, ptr %{{.*}}) +func.func @foo6(%box : !fir.box>>>, %i : i64 , %res : !fir.ref>) { + // CHECK: %[[addr_gep:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[box]], i32 0, i32 0 + // CHECK: %[[addr:.*]] = load ptr, ptr %[[addr_gep]] + // CHECK: %[[stride_gep:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[box]], i32 0, i32 7, i64 0, i32 2 + // CHECK: %[[stride:.*]] = load i64, ptr %[[stride_gep]] + // CHECK: %[[mul:.*]] = mul i64 %{{.*}}, %[[stride]] + // CHECK: %[[offset:.*]] = add i64 %[[mul]], 0 + // CHECK: %[[gep:.*]] = getelementptr i8, ptr %[[addr]], i64 %[[offset]] + %coor = fir.coordinate_of %box, %i : (!fir.box>>>, i64) -> !fir.ref> + + // CHECK: load [1 x i8], ptr %[[gep]] + %load = fir.load %coor : !fir.ref> + fir.store %load to %res : !fir.ref> + return +} diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir --- a/flang/test/Fir/invalid.fir +++ b/flang/test/Fir/invalid.fir @@ -287,17 
+287,17 @@ func.func @test_coordinate_of(%arr : !fir.ref>) { %1 = arith.constant 10 : i32 - // expected-error@+1 {{'fir.coordinate_of' op cannot apply coordinate_of to this type}} - %2 = fir.coordinate_of %arr, %1 : (!fir.ref>, i32) -> !fir.ref + // expected-error@+1 {{'fir.coordinate_of' op cannot apply to this element type}} + %2 = fir.coordinate_of %arr, %1 : (!fir.ref>, i32) -> !fir.ref> return } // ----- -func.func @test_coordinate_of(%arr : !fir.ref>) { +func.func @test_coordinate_of(%arr : !fir.ref>) { %1 = arith.constant 10 : i32 - // expected-error@+1 {{'fir.coordinate_of' op cannot apply coordinate_of to this type}} - %2 = fir.coordinate_of %arr, %1 : (!fir.ref>, i32) -> !fir.ref + // expected-error@+1 {{'fir.coordinate_of' op cannot apply to character singleton}} + %2 = fir.coordinate_of %arr, %1, %1 : (!fir.ref>, i32, i32) -> !fir.ref> return }