flang CodeGen: lower `fir.coordinate_of` with a `!fir.ref` base operand

Replaces the sequence-type TODO with `doRewriteRef`, which folds the
coordinate operands into a single LLVM GEP on the base pointer, and makes
`doRewriteBox` return its `mlir::LogicalResult`. Adds FileCheck tests for
arrays of known and unknown extent and for derived types inside `!fir.ref`,
plus a negative test for elements with a dynamic size.

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -2608,16 +2608,14 @@ return success(); } - // Box type - get the base pointer from the box - if (auto boxTy = baseObjectTy.dyn_cast()) { - doRewriteBox(coor, ty, operands, loc, rewriter); - return success(); + // Boxed type - get the base pointer from the box + if (baseObjectTy.dyn_cast()) { + return doRewriteBox(coor, ty, operands, loc, rewriter); } - // Sequence type (e.g. fir.array) - if (auto arrTy = objectTy.dyn_cast()) { - doRewriteSequence(loc); - return success(); + // Reference type - address the component with a GEP on the base pointer + if (baseObjectTy.dyn_cast()) { + return doRewriteRef(coor, ty, operands, loc, rewriter); } return rewriter.notifyMatchFailure( @@ -2644,10 +2642,66 @@ fir::emitFatalError(val.getLoc(), "must be a constant"); } + bool hasSubDimensions(mlir::Type type) const { + return type.isa() || type.isa() || + type.isa(); + } + + bool validCoordinate(mlir::Type type, mlir::ValueRange coors) const { + const auto sz = coors.size(); + std::remove_const_t i = 0; + bool subEle = false; + bool ptrEle = false; + for (; i < sz; ++i) { + auto nxtOpnd = coors[i]; + if (auto arrTy = type.dyn_cast()) { + subEle = true; + i += arrTy.getDimension() - 1; + type = arrTy.getEleTy(); + } else if (auto strTy = type.dyn_cast()) { + subEle = true; + type = strTy.getType(getFieldNumber(strTy, nxtOpnd)); + } else if (auto strTy = type.dyn_cast()) { + subEle = true; + type = strTy.getType(getIntValue(nxtOpnd)); + } else { + ptrEle = true; + } + } + if (ptrEle) + return (!subEle) && (sz == 1); + return subEle && (i >= sz); + } + + /// Walk the abstract memory layout and determine if the path traverses any + /// array types with unknown shape. Return true iff all the array types have a + /// constant shape along the path.
+ bool arraysHaveKnownShape(mlir::Type type, mlir::ValueRange coors) const { + const auto sz = coors.size(); + std::remove_const_t i = 0; + for (; i < sz; ++i) { + auto nxtOpnd = coors[i]; + if (auto arrTy = type.dyn_cast()) { + if (fir::sequenceWithNonConstantShape(arrTy)) + return false; + i += arrTy.getDimension() - 1; + type = arrTy.getEleTy(); + } else if (auto strTy = type.dyn_cast()) { + type = strTy.getType(getFieldNumber(strTy, nxtOpnd)); + } else if (auto strTy = type.dyn_cast()) { + type = strTy.getType(getIntValue(nxtOpnd)); + } else { + return true; + } + } + return true; + } + private: - void doRewriteBox(fir::CoordinateOp coor, mlir::Type ty, - mlir::ValueRange operands, mlir::Location loc, - mlir::ConversionPatternRewriter &rewriter) const { + mlir::LogicalResult + doRewriteBox(fir::CoordinateOp coor, mlir::Type ty, mlir::ValueRange operands, + mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter) const { mlir::Type boxObjTy = coor.getBaseType(); assert(boxObjTy.dyn_cast() && "This is not a `fir.box`"); @@ -2732,11 +2786,117 @@ } rewriter.replaceOpWithNewOp(coor, ty, resultAddr); - return; + return success(); } - void doRewriteSequence(mlir::Location loc) const { - TODO(loc, "fir.coordinate_of codegen for sequence types"); + /// Lower a `fir.coordinate_of` whose base is a reference into one GEP on the + /// base pointer. Emits an error for elements with a dynamic size. + mlir::LogicalResult + doRewriteRef(fir::CoordinateOp coor, mlir::Type ty, mlir::ValueRange operands, + mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Type baseObjectTy = coor.getBaseType(); + + auto currentObjTy = fir::dyn_cast_ptrOrBoxEleTy(baseObjectTy); + // Arrays, records, tuples and (single-index) scalars are all handled below. + assert(currentObjTy && "Unsupported type"); + bool hasSubdimension = hasSubDimensions(currentObjTy); + bool columnIsDeferred = !hasSubdimension; + + if (!validCoordinate(currentObjTy, operands.drop_front(1))) + TODO(loc, "coordinate has incorrect dimension"); + + // True iff every array along the coordinate path has a constant shape. + const bool hasKnownShape =
arraysHaveKnownShape(currentObjTy, operands.drop_front(1)); + + // If only the column is `?`, then we can simply place the column value in + // the 0-th GEP position. + if (auto arrTy = currentObjTy.dyn_cast()) { + if (!hasKnownShape) { + const auto sz = arrTy.getDimension(); + if (arraysHaveKnownShape(arrTy.getEleTy(), + operands.drop_front(1 + sz))) { + auto shape = arrTy.getShape(); + bool allConst = true; + for (std::remove_const_t i = 0; i < sz - 1; ++i) + if (shape[i] < 0) { + allConst = false; + break; + } + if (allConst) + columnIsDeferred = true; + } + } + } + + if (fir::hasDynamicSize(fir::unwrapSequenceType(currentObjTy))) { + mlir::emitError( + loc, "fir.coordinate_of with a dynamic element size is unsupported"); + return failure(); + } + + if (hasKnownShape || columnIsDeferred) { + SmallVector offs; + if (hasKnownShape && hasSubdimension) { + mlir::LLVM::ConstantOp c0 = + genConstantIndex(loc, lowerTy().indexType(), rewriter, 0); + offs.push_back(c0); + } + const auto sz = operands.size(); + Optional dims; + SmallVector arrIdx; + for (std::remove_const_t i = 1; i < sz; ++i) { + auto nxtOpnd = operands[i]; + + if (!currentObjTy) + TODO(loc, "invalid coordinate/check failed"); + + // check if the i-th coordinate relates to an array + if (dims.hasValue()) { + arrIdx.push_back(nxtOpnd); + int dimsLeft = *dims; + if (dimsLeft > 1) { + dims = dimsLeft - 1; + continue; + } + currentObjTy = currentObjTy.cast().getEleTy(); + // append array range in reverse (FIR arrays are column-major) + offs.append(arrIdx.rbegin(), arrIdx.rend()); + arrIdx.clear(); + dims.reset(); + continue; + } + if (auto arrTy = currentObjTy.dyn_cast()) { + int d = arrTy.getDimension() - 1; + if (d > 0) { + dims = d; + arrIdx.push_back(nxtOpnd); + continue; + } + currentObjTy = currentObjTy.cast().getEleTy(); + offs.push_back(nxtOpnd); + continue; + } + + // check if the i-th coordinate relates to a field + if (auto strTy = currentObjTy.dyn_cast()) { + currentObjTy =
strTy.getType(getFieldNumber(strTy, nxtOpnd)); + } else if (auto strTy = currentObjTy.dyn_cast()) { + currentObjTy = strTy.getType(getIntValue(nxtOpnd)); + } else { + currentObjTy = nullptr; + } + offs.push_back(nxtOpnd); + } + if (dims.hasValue()) + offs.append(arrIdx.rbegin(), arrIdx.rend()); + mlir::Value base = operands[0]; + mlir::Value retval = genGEP(loc, ty, rewriter, base, offs); + rewriter.replaceOp(coor, retval); + return success(); + } + TODO(loc, "fir.coordinate_of base operand has unsupported type"); } }; diff --git a/flang/test/Fir/convert-to-llvm-invalid.fir b/flang/test/Fir/convert-to-llvm-invalid.fir --- a/flang/test/Fir/convert-to-llvm-invalid.fir +++ b/flang/test/Fir/convert-to-llvm-invalid.fir @@ -88,3 +88,13 @@ %zero = arith.constant 0 : i32 return %zero : i32 } + +// ----- + +// `fir.coordinate_of` - dynamically sized arrays are not supported +func @coordinate_of_dynamic_array(%arg0: !fir.ref>>, %arg1: index) { + // expected-error@+2{{fir.coordinate_of with a dynamic element size is unsupported}} + // expected-error@+1{{failed to legalize operation 'fir.coordinate_of'}} + %p = fir.coordinate_of %arg0, %arg1 : (!fir.ref>>, index) -> !fir.ref + return +} diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -1956,7 +1956,7 @@ // ----- -// Test `fir.coordinate_of` conversion +// Test `fir.coordinate_of` conversion (items inside `!fir.box`) // 1. COMPLEX TYPE (`fir.complex` is a special case) // Complex type wrapped in `fir.ref` @@ -1985,7 +1985,7 @@ // ----- -// Test `fir.coordinate_of` conversion +// Test `fir.coordinate_of` conversion (items inside `!fir.box`) // 2. BOX TYPE (objects wrapped in `fir.box`) // Derived type - basic case (1 index) @@ -2038,7 +2038,7 @@ // ----- -// Test `fir.coordinate_of` conversion +// Test `fir.coordinate_of` conversion (items inside `!fir.box`) // 3.
BOX TYPE - `fir.array` wrapped in `fir.box` // `fir.array` inside a `fir.box` (1d) @@ -2152,7 +2152,7 @@ // ----- -// Test `fir.coordinate_of` conversion +// Test `fir.coordinate_of` conversion (items inside `!fir.box`) // 4. BOX TYPE - `fir.derived` inside `fir.array` func @coordinate_box_derived_inside_array(%arg0: !fir.box>>, %arg1 : index) { @@ -2185,3 +2185,72 @@ // CHECK: %[[VAL_21:.*]] = llvm.bitcast %[[VAL_20]] : !llvm.ptr to !llvm.ptr // CHECK: %[[VAL_22:.*]] = llvm.bitcast %[[VAL_21]] : !llvm.ptr to !llvm.ptr // CHECK: llvm.return + +// ----- + +// Test `fir.coordinate_of` conversion (items inside `!fir.ref`) + +// 5.1. `fir.array` (unknown extent: index used directly; known extents: GEP starts with a 0 index) +func @coordinate_array_unknown_size_1d(%arg0: !fir.ref>, %arg1 : index) { + %q = fir.coordinate_of %arg0, %arg1 : (!fir.ref>, index) -> !fir.ref + return +} +// CHECK-LABEL: llvm.func @coordinate_array_unknown_size_1d( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64) { +// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK: llvm.return +// CHECK: } + +func @coordinate_array_known_size_1d(%arg0: !fir.ref>, %arg1 : index) { + %q = fir.coordinate_of %arg0, %arg1 : (!fir.ref>, index) -> !fir.ref + return +} +// CHECK-LABEL: llvm.func @coordinate_array_known_size_1d( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr>, +// CHECK-SAME: %[[VAL_1:.*]]: i64) { +// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr>, i64, i64) -> !llvm.ptr +// CHECK: llvm.return +// CHECK: } + +func @coordinate_array_known_size_2d(%arg0: !fir.ref>, %arg1 : index) { + %q = fir.coordinate_of %arg0, %arg1, %arg1 : (!fir.ref>, index, index) -> !fir.ref + return +} +// CHECK-LABEL: llvm.func @coordinate_array_known_size_2d( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr>>, +// CHECK-SAME: %[[VAL_1:.*]]: i64) { +// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK:
%[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr>>, i64, i64, i64) -> !llvm.ptr +// CHECK: llvm.return +// CHECK: } + +// 5.2. `fir.derived` (field indices become i32 GEP operands after a leading i64 0 index) +func @coordinate_ref_derived(%arg0: !fir.ref>) { + %idx = fir.field_index field_2, !fir.type + %q = fir.coordinate_of %arg0, %idx : (!fir.ref>, !fir.field) -> !fir.ref + return +} +// CHECK-LABEL: llvm.func @coordinate_ref_derived( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr>) { +// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr>, i64, i32) -> !llvm.ptr +// CHECK: llvm.return +// CHECK: } + +func @coordinate_ref_derived_nested(%arg0: !fir.ref, field_2:i32}>>) { + %idx0 = fir.field_index field_1, !fir.type, field_2:i32}> + %idx1 = fir.field_index inner2, !fir.type + %q = fir.coordinate_of %arg0, %idx0, %idx1 : (!fir.ref, field_2:i32}>>, !fir.field, !fir.field) -> !fir.ref + return +} +// CHECK-LABEL: llvm.func @coordinate_ref_derived_nested( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, i32)>>) { +// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_1]], %[[VAL_2]]] : (!llvm.ptr, i32)>>, i64, i32, i32) -> !llvm.ptr +// CHECK: llvm.return +// CHECK: }