diff --git a/flang/lib/Optimizer/CodeGen/TBAABuilder.cpp b/flang/lib/Optimizer/CodeGen/TBAABuilder.cpp --- a/flang/lib/Optimizer/CodeGen/TBAABuilder.cpp +++ b/flang/lib/Optimizer/CodeGen/TBAABuilder.cpp @@ -12,6 +12,7 @@ #include "TBAABuilder.h" #include "flang/Optimizer/Dialect/FIRType.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -159,9 +160,13 @@ else tbaaTagSym = getDataAccessTag(baseFIRType, accessFIRType, gep); - if (tbaaTagSym) - op->setAttr(LLVMDialect::getTBAAAttrName(), - ArrayAttr::get(op->getContext(), tbaaTagSym)); + if (!tbaaTagSym) + return; + + auto tbaaAttr = ArrayAttr::get(op->getContext(), tbaaTagSym); + llvm::TypeSwitch(op) + .Case([&](auto memOp) { memOp.setTbaaAttr(tbaaAttr); }) + .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); }); } } // namespace fir diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir --- a/flang/test/Fir/tbaa.fir +++ b/flang/test/Fir/tbaa.fir @@ -28,10 +28,10 @@ // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(10 : i32) : i32 // CHECK: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]][0, 0] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr>> -// CHECK: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr>> +// CHECK: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr>> // CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[VAL_9:.*]] = llvm.getelementptr %[[VAL_0]][0, 7, 0, 2] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr -// CHECK: %[[VAL_10:.*]] = llvm.load %[[VAL_9]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr +// CHECK: %[[VAL_10:.*]] = llvm.load %[[VAL_9]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr // CHECK: %[[VAL_11:.*]] = llvm.mul 
%[[VAL_4]], %[[VAL_10]] : i64 // CHECK: %[[VAL_12:.*]] = llvm.add %[[VAL_11]], %[[VAL_8]] : i64 // CHECK: %[[VAL_13:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr> to !llvm.ptr @@ -40,11 +40,11 @@ // CHECK: %[[VAL_16:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[VAL_17:.*]] = llvm.mlir.constant(-1 : i32) : i32 // CHECK: %[[VAL_18:.*]] = llvm.getelementptr %[[VAL_0]][0, 8] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr> -// CHECK: %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr> +// CHECK: %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr> // CHECK: %[[VAL_20:.*]] = llvm.getelementptr %[[VAL_0]][0, 1] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr -// CHECK: %[[VAL_21:.*]] = llvm.load %[[VAL_20]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr +// CHECK: %[[VAL_21:.*]] = llvm.load %[[VAL_20]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr // CHECK: %[[VAL_22:.*]] = llvm.getelementptr %[[VAL_0]][0, 4] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr -// CHECK: %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr +// CHECK: %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr // CHECK: %[[VAL_24:.*]] = llvm.mlir.undef : !llvm.struct<(ptr>, i64, i32, i8, i8, i8, i8, ptr, array<1 x i64>)> // CHECK: %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_21]], %[[VAL_24]][1] : !llvm.struct<(ptr>, i64, i32, i8, i8, i8, i8, ptr, array<1 x i64>)> // CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(20180515 : i32) : i32 @@ -64,15 +64,15 @@ // CHECK: %[[VAL_40:.*]] = llvm.insertvalue %[[VAL_39]], %[[VAL_38]][7] : !llvm.struct<(ptr>, i64, i32, i8, i8, i8, i8, ptr, array<1 x i64>)> // CHECK: %[[VAL_41:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr> to !llvm.ptr> // CHECK: 
%[[VAL_42:.*]] = llvm.insertvalue %[[VAL_41]], %[[VAL_40]][0] : !llvm.struct<(ptr>, i64, i32, i8, i8, i8, i8, ptr, array<1 x i64>)> -// CHECK: llvm.store %[[VAL_42]], %[[VAL_2]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr>, i64, i32, i8, i8, i8, i8, ptr, array<1 x i64>)>> +// CHECK: llvm.store %[[VAL_42]], %[[VAL_2]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr>, i64, i32, i8, i8, i8, i8, ptr, array<1 x i64>)>> // CHECK: %[[VAL_43:.*]] = llvm.getelementptr %[[VAL_2]][0, 4] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, ptr, array<1 x i64>)>>) -> !llvm.ptr -// CHECK: %[[VAL_44:.*]] = llvm.load %[[VAL_43]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr +// CHECK: %[[VAL_44:.*]] = llvm.load %[[VAL_43]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr // CHECK: %[[VAL_45:.*]] = llvm.icmp "eq" %[[VAL_44]], %[[VAL_3]] : i8 // CHECK: llvm.cond_br %[[VAL_45]], ^bb1, ^bb2 // CHECK: ^bb1: // CHECK: %[[VAL_46:.*]] = llvm.getelementptr %[[VAL_2]][0, 0] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, ptr, array<1 x i64>)>>) -> !llvm.ptr> -// CHECK: %[[VAL_47:.*]] = llvm.load %[[VAL_46]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr> -// CHECK: llvm.store %[[VAL_5]], %[[VAL_47]] {llvm.tbaa = [@__flang_tbaa::@[[DATAT:tag_[0-9]*]]]} : !llvm.ptr +// CHECK: %[[VAL_47:.*]] = llvm.load %[[VAL_46]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr> +// CHECK: llvm.store %[[VAL_5]], %[[VAL_47]] {tbaa = [@__flang_tbaa::@[[DATAT:tag_[0-9]*]]]} : !llvm.ptr // CHECK: llvm.br ^bb2 // CHECK: ^bb2: // CHECK: llvm.return @@ -133,24 +133,24 @@ // CHECK: %[[VAL_8:.*]] = llvm.mlir.addressof @_QQcl.2E2F64756D6D792E66393000 : !llvm.ptr> // CHECK: %[[VAL_9:.*]] = llvm.bitcast %[[VAL_8]] : !llvm.ptr> to !llvm.ptr // CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_9]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath} : (i32, !llvm.ptr, i32) -> !llvm.ptr -// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_7]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : 
!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>> -// CHECK: llvm.store %[[VAL_11]], %[[VAL_3]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>> +// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_7]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>> +// CHECK: llvm.store %[[VAL_11]], %[[VAL_3]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>> // CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>, i64) -> !llvm.ptr -// CHECK: %[[VAL_13:.*]] = llvm.load %[[VAL_12]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr +// CHECK: %[[VAL_13:.*]] = llvm.load %[[VAL_12]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr // CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 1] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>, i64) -> !llvm.ptr -// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr +// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr // CHECK: %[[VAL_16:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 2] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>, i64) -> !llvm.ptr -// CHECK: %[[VAL_17:.*]] = llvm.load %[[VAL_16]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr +// CHECK: %[[VAL_17:.*]] = llvm.load %[[VAL_16]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr // CHECK: %[[VAL_18:.*]] = llvm.getelementptr %[[VAL_3]][0, 8] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr> -// CHECK: %[[VAL_19:.*]] = llvm.load 
%[[VAL_18]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr> +// CHECK: %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr> // CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[VAL_21:.*]] = llvm.mlir.constant(-1 : i32) : i32 // CHECK: %[[VAL_22:.*]] = llvm.getelementptr %[[VAL_3]][0, 1] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr -// CHECK: %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr +// CHECK: %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr // CHECK: %[[VAL_24:.*]] = llvm.getelementptr %[[VAL_3]][0, 4] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr -// CHECK: %[[VAL_25:.*]] = llvm.load %[[VAL_24]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr +// CHECK: %[[VAL_25:.*]] = llvm.load %[[VAL_24]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr // CHECK: %[[VAL_26:.*]] = llvm.getelementptr %[[VAL_3]][0, 8] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr> -// CHECK: %[[VAL_27:.*]] = llvm.load %[[VAL_26]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr> +// CHECK: %[[VAL_27:.*]] = llvm.load %[[VAL_26]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr> // CHECK: %[[VAL_28:.*]] = llvm.mlir.undef : !llvm.struct<(ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)> // CHECK: %[[VAL_29:.*]] = llvm.insertvalue %[[VAL_23]], %[[VAL_28]][1] : !llvm.struct<(ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)> // CHECK: %[[VAL_30:.*]] = llvm.mlir.constant(20180515 : i32) : i32 @@ -169,13 +169,13 @@ // CHECK: %[[VAL_43:.*]] = llvm.bitcast %[[VAL_27]] : !llvm.ptr to !llvm.ptr // CHECK: %[[VAL_44:.*]] = llvm.insertvalue %[[VAL_43]], %[[VAL_42]][8] : !llvm.struct<(ptr>, i64, i32, i8, i8, i8, i8, 
array<1 x array<3 x i64>>, ptr, array<1 x i64>)> // CHECK: %[[VAL_45:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, 0, 0] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr -// CHECK: %[[VAL_46:.*]] = llvm.load %[[VAL_45]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr +// CHECK: %[[VAL_46:.*]] = llvm.load %[[VAL_45]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr // CHECK: %[[VAL_47:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, 0, 1] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr -// CHECK: %[[VAL_48:.*]] = llvm.load %[[VAL_47]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr +// CHECK: %[[VAL_48:.*]] = llvm.load %[[VAL_47]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr // CHECK: %[[VAL_49:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, 0, 2] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr -// CHECK: %[[VAL_50:.*]] = llvm.load %[[VAL_49]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr +// CHECK: %[[VAL_50:.*]] = llvm.load %[[VAL_49]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr // CHECK: %[[VAL_51:.*]] = llvm.getelementptr %[[VAL_3]][0, 0] : (!llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>>) -> !llvm.ptr>> -// CHECK: %[[VAL_52:.*]] = llvm.load %[[VAL_51]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr>> +// CHECK: %[[VAL_52:.*]] = llvm.load %[[VAL_51]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr>> // CHECK: %[[VAL_53:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[VAL_54:.*]] = llvm.mlir.constant(1 : i64) : i64 // CHECK: %[[VAL_55:.*]] = llvm.icmp "eq" %[[VAL_48]], %[[VAL_53]] : i64 @@ -185,7 +185,7 @@ // CHECK: %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)> // CHECK: %[[VAL_60:.*]] = llvm.bitcast %[[VAL_52]] : 
!llvm.ptr> to !llvm.ptr> // CHECK: %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_60]], %[[VAL_59]][0] : !llvm.struct<(ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)> -// CHECK: llvm.store %[[VAL_61]], %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>> +// CHECK: llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>> // CHECK: %[[VAL_62:.*]] = llvm.bitcast %[[VAL_1]] : !llvm.ptr>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>> to !llvm.ptr>, i64, i32, i8, i8, i8, i8, ptr, array<1 x i64>)>> // CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_62]]) {fastmathFlags = #llvm.fastmath} : (!llvm.ptr, !llvm.ptr>, i64, i32, i8, i8, i8, i8, ptr, array<1 x i64>)>>) -> i1 // CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath} : (!llvm.ptr) -> i32 @@ -253,7 +253,7 @@ // CHECK-LABEL: llvm.func @tbaa( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, i64, i32, i8, i8, i8, i8)>>) -> i32 { // CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr -// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr +// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr // CHECK: llvm.return %[[VAL_2]] : i32 // CHECK: } @@ -275,7 +275,7 @@ // CHECK-LABEL: llvm.func @tbaa( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, i64, i32, i8, i8, i8, i8)>>) -> i1 { // CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr -// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr +// CHECK: %[[VAL_2:.*]] = 
llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr // CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAL_4:.*]] = llvm.icmp "ne" %[[VAL_2]], %[[VAL_3]] : i32 // CHECK: llvm.return %[[VAL_4]] : i1 @@ -299,7 +299,7 @@ // CHECK-LABEL: llvm.func @tbaa( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, i64, i32, i8, i8, i8, i8)>>) -> i32 { // CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 1] : (!llvm.ptr, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr -// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr +// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr // CHECK: llvm.return %[[VAL_2]] : i32 // CHECK: } @@ -321,7 +321,7 @@ // CHECK-LABEL: llvm.func @tbaa( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, i64, i32, i8, i8, i8, i8)>>) -> i1 { // CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr -// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr +// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr // CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(2 : i32) : i32 // CHECK: %[[VAL_4:.*]] = llvm.and %[[VAL_2]], %[[VAL_3]] : i32 // CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32 @@ -353,11 +353,11 @@ // CHECK: %[[VAL_4:.*]] = llvm.sub %[[VAL_1]], %[[VAL_2]] : i64 // CHECK: %[[VAL_5:.*]] = llvm.mul %[[VAL_4]], %[[VAL_2]] : i64 // CHECK: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]][0, 7, 0, 2] : (!llvm.ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>>) -> !llvm.ptr -// CHECK: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr +// CHECK: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr // CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_5]], %[[VAL_7]] : 
i64 // CHECK: %[[VAL_9:.*]] = llvm.add %[[VAL_8]], %[[VAL_3]] : i64 // CHECK: %[[VAL_10:.*]] = llvm.getelementptr %[[VAL_0]][0, 0] : (!llvm.ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>>) -> !llvm.ptr> -// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr> +// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr> // CHECK: %[[VAL_12:.*]] = llvm.bitcast %[[VAL_11]] : !llvm.ptr to !llvm.ptr // CHECK: %[[VAL_13:.*]] = llvm.getelementptr %[[VAL_12]]{{\[}}%[[VAL_9]]] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK: %[[VAL_14:.*]] = llvm.bitcast %[[VAL_13]] : !llvm.ptr to !llvm.ptr diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -35,11 +35,7 @@ let extraClassDeclaration = [{ /// Name of the data layout attributes. static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; } - static StringRef getNoAliasScopesAttrName() { return "noalias_scopes"; } - static StringRef getAliasScopesAttrName() { return "alias_scopes"; } static StringRef getLoopAttrName() { return "llvm.loop"; } - static StringRef getAccessGroupsAttrName() { return "access_groups"; } - static StringRef getTBAAAttrName() { return "llvm.tbaa"; } /// Names of llvm parameter attributes. 
static StringRef getAlignAttrName() { return "llvm.align"; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -350,6 +350,7 @@ OptionalAttr:$access_groups, OptionalAttr:$alias_scopes, OptionalAttr:$noalias_scopes, + OptionalAttr:$tbaa, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal); let results = (outs LLVM_LoadableType:$res); @@ -390,6 +391,7 @@ OptionalAttr:$access_groups, OptionalAttr:$alias_scopes, OptionalAttr:$noalias_scopes, + OptionalAttr:$tbaa, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal); string llvmInstName = "Store"; diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -120,15 +120,14 @@ /// in these blocks. void forgetMapping(Region ®ion); - /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM - /// dialect access group operation. - llvm::MDNode *getAccessGroup(Operation &opInst, + /// Returns the LLVM metadata corresponding to a symbol reference to an mlir + /// LLVM dialect access group operation. + llvm::MDNode *getAccessGroup(Operation *op, SymbolRefAttr accessGroupRef) const; - /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM - /// dialect alias scope operation - llvm::MDNode *getAliasScope(Operation &opInst, - SymbolRefAttr aliasScopeRef) const; + /// Returns the LLVM metadata corresponding to a symbol reference to an mlir + /// LLVM dialect alias scope operation + llvm::MDNode *getAliasScope(Operation *op, SymbolRefAttr aliasScopeRef) const; // Sets LLVM metadata for memory operations that are in a parallel loop. 
void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst); @@ -287,9 +286,9 @@ /// metadata nodes for them and their domains. LogicalResult createAliasScopeMetadata(); - /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM - /// dialect TBAATagOp operation. - llvm::MDNode *getTBAANode(Operation &memOp, SymbolRefAttr tagRef) const; + /// Returns the LLVM metadata corresponding to a symbol reference to an mlir + /// LLVM dialect TBAATagOp operation. + llvm::MDNode *getTBAANode(Operation *op, SymbolRefAttr tagRef) const; /// Process tbaa LLVM Metadata operations and create LLVM /// metadata nodes for them. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -668,53 +668,60 @@ // Builder, printer and parser for for LLVM::LoadOp. //===----------------------------------------------------------------------===// -LogicalResult verifySymbolAttribute( - Operation *op, StringRef attributeName, +/// Verifies the given array attribute contains symbol references and checks the +/// referenced symbol types using the provided verification function. +static LogicalResult verifyMemOpSymbolRefs( + Operation *op, StringRef name, ArrayAttr symbolRefs, llvm::function_ref verifySymbolType) { - if (Attribute attribute = op->getAttr(attributeName)) { - // Verify that the attribute is a symbol ref array attribute, - // because this constraint is not verified for all attribute - // names processed here (e.g. 'tbaa'). This verification - // is redundant in some cases.
- if (!(attribute.isa() && - llvm::all_of(attribute.cast(), [&](Attribute attr) { - return attr && attr.isa(); - }))) - return op->emitOpError("attribute '") - << attributeName - << "' failed to satisfy constraint: symbol ref array attribute"; - - for (SymbolRefAttr symbolRef : - attribute.cast().getAsRange()) { - StringAttr metadataName = symbolRef.getRootReference(); - StringAttr symbolName = symbolRef.getLeafReference(); - // We want @metadata::@symbol, not just @symbol - if (metadataName == symbolName) { - return op->emitOpError() << "expected '" << symbolRef - << "' to specify a fully qualified reference"; - } - auto metadataOp = SymbolTable::lookupNearestSymbolFrom( - op->getParentOp(), metadataName); - if (!metadataOp) - return op->emitOpError() - << "expected '" << symbolRef << "' to reference a metadata op"; - Operation *symbolOp = - SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); - if (!symbolOp) - return op->emitOpError() - << "expected '" << symbolRef << "' to be a valid reference"; - if (failed(verifySymbolType(symbolOp, symbolRef))) { - return failure(); - } + assert(symbolRefs && "expected a non-null attribute"); + + // Verify that the attribute is a symbol ref array attribute, + // because this constraint is not verified for all attribute + // names processed here (e.g. 'tbaa'). This verification + // is redundant in some cases. 
+ if (!llvm::all_of(symbolRefs, [](Attribute attr) { + return attr && attr.isa(); + })) + return op->emitOpError("attribute '") + << name + << "' failed to satisfy constraint: symbol ref array attribute"; + + for (SymbolRefAttr symbolRef : symbolRefs.getAsRange()) { + StringAttr metadataName = symbolRef.getRootReference(); + StringAttr symbolName = symbolRef.getLeafReference(); + // We want @metadata::@symbol, not just @symbol + if (metadataName == symbolName) { + return op->emitOpError() << "expected '" << symbolRef + << "' to specify a fully qualified reference"; + } + auto metadataOp = SymbolTable::lookupNearestSymbolFrom( + op->getParentOp(), metadataName); + if (!metadataOp) + return op->emitOpError() + << "expected '" << symbolRef << "' to reference a metadata op"; + Operation *symbolOp = + SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); + if (!symbolOp) + return op->emitOpError() + << "expected '" << symbolRef << "' to be a valid reference"; + if (failed(verifySymbolType(symbolOp, symbolRef))) { + return failure(); } } + return success(); } -// Verifies that metadata ops are wired up properly. +/// Verifies the given array attribute contains symbol references that point to +/// metadata operations of the given type. 
template -static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { +static LogicalResult +verifyMemOpSymbolRefsPointTo(Operation *op, StringRef name, + std::optional symbolRefs) { + if (!symbolRefs) + return success(); + auto verifySymbolType = [op](Operation *symbolOp, SymbolRefAttr symbolRef) -> LogicalResult { if (!isa(symbolOp)) { @@ -724,35 +731,33 @@ } return success(); }; - - return verifySymbolAttribute(op, attributeName, verifySymbolType); + return verifyMemOpSymbolRefs(op, name, *symbolRefs, verifySymbolType); } -static LogicalResult verifyMemoryOpMetadata(Operation *op) { - // access_groups - if (failed(verifyOpMetadata( - op, LLVMDialect::getAccessGroupsAttrName()))) +/// Verifies the types of the metadata operations referenced by aliasing and +/// access group metadata. +template +static LogicalResult verifyMemOpMetadata(OpTy memOp) { + if (failed(verifyMemOpSymbolRefsPointTo( + memOp, memOp.getAccessGroupsAttrName(), memOp.getAccessGroups()))) return failure(); - // alias_scopes - if (failed(verifyOpMetadata( - op, LLVMDialect::getAliasScopesAttrName()))) + if (failed(verifyMemOpSymbolRefsPointTo( + memOp, memOp.getAliasScopesAttrName(), memOp.getAliasScopes()))) return failure(); - // noalias_scopes - if (failed(verifyOpMetadata( - op, LLVMDialect::getNoAliasScopesAttrName()))) + if (failed(verifyMemOpSymbolRefsPointTo( + memOp, memOp.getNoaliasScopesAttrName(), memOp.getNoaliasScopes()))) return failure(); - // tbaa - if (failed(verifyOpMetadata(op, - LLVMDialect::getTBAAAttrName()))) + if (failed(verifyMemOpSymbolRefsPointTo( + memOp, memOp.getTbaaAttrName(), memOp.getTbaa()))) return failure(); return success(); } -LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); } +LogicalResult LoadOp::verify() { return verifyMemOpMetadata(*this); } void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, Value addr, unsigned alignment, bool isVolatile, @@ -828,7 +833,7 @@ // Builder, printer and parser for
LLVM::StoreOp. //===----------------------------------------------------------------------===// -LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); } +LogicalResult StoreOp::verify() { return verifyMemOpMetadata(*this); } void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, Value addr, unsigned alignment, bool isVolatile, diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -76,6 +76,30 @@ return convertibleMetadata; } +namespace { +/// Helper class to attach metadata attributes to specific operation types. It +/// specializes TypeSwitch to take an Operation and return a LogicalResult. +template +struct AttributeSetter { + AttributeSetter(Operation *op) : op(op) {} + + /// Calls `attachFn` on the provided Operation if it has one of + /// the given operation types. Returns failure otherwise. + template + LogicalResult apply(CallableT &&attachFn) { + return llvm::TypeSwitch(op) + .Case([&attachFn](auto concreteOp) { + attachFn(concreteOp); + return success(); + }) + .Default([&](auto) { return failure(); }); + } + +private: + Operation *op; +}; +} // namespace + /// Converts the given profiling metadata `node` to an MLIR profiling attribute /// and attaches it to the imported operation if the translation succeeds. /// Returns failure otherwise. @@ -129,16 +153,10 @@ branchWeights.push_back(branchWeight->getZExtValue()); } - // Attach the branch weights to the operations that support it. 
- return llvm::TypeSwitch(op) - .Case([&](auto branchWeightOp) { + return AttributeSetter(op).apply( + [&](auto branchWeightOp) { branchWeightOp.setBranchWeightsAttr( builder.getI32VectorAttr(branchWeights)); - return success(); - }) - .Default([op](auto) { - return op->emitWarning() - << op->getName() << " does not support branch weights"; }); } @@ -151,9 +169,9 @@ if (!tbaaTagSym) return failure(); - op->setAttr(LLVMDialect::getTBAAAttrName(), - ArrayAttr::get(op->getContext(), tbaaTagSym)); - return success(); + return AttributeSetter(op).apply([&](auto memOp) { + memOp.setTbaaAttr(ArrayAttr::get(memOp.getContext(), tbaaTagSym)); + }); } /// Looks up all the symbol references pointing to the access group operations @@ -169,9 +187,10 @@ SmallVector accessGroupAttrs(accessGroups->begin(), accessGroups->end()); - op->setAttr(LLVMDialect::getAccessGroupsAttrName(), - ArrayAttr::get(op->getContext(), accessGroupAttrs)); - return success(); + return AttributeSetter(op).apply([&](auto memOp) { + memOp.setAccessGroupsAttr( + ArrayAttr::get(memOp.getContext(), accessGroupAttrs)); + }); } /// Converts the given loop metadata node to an MLIR loop annotation attribute diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp --- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp @@ -200,7 +200,7 @@ llvm::MDString::get(ctx, "llvm.loop.parallel_accesses")); for (SymbolRefAttr accessGroupRef : parallelAccessGroups) parallelAccess.push_back( - moduleTranslation.getAccessGroup(*op, accessGroupRef)); + moduleTranslation.getAccessGroup(op, accessGroupRef)); metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess)); } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -986,12 +986,12 @@ } 
llvm::MDNode * -ModuleTranslation::getAccessGroup(Operation &opInst, +ModuleTranslation::getAccessGroup(Operation *op, SymbolRefAttr accessGroupRef) const { auto metadataName = accessGroupRef.getRootReference(); auto accessGroupName = accessGroupRef.getLeafReference(); auto metadataOp = SymbolTable::lookupNearestSymbolFrom( - opInst.getParentOp(), metadataName); + op->getParentOp(), metadataName); auto *accessGroupOp = SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); return accessGroupMetadataMapping.lookup(accessGroupOp); @@ -1010,23 +1010,28 @@ void ModuleTranslation::setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst) { - auto accessGroups = - op->getAttrOfType(LLVMDialect::getAccessGroupsAttrName()); - if (accessGroups && !accessGroups.empty()) { + auto populateGroupsMetadata = [&](std::optional groupRefs) { + if (!groupRefs || groupRefs->empty()) + return; + llvm::Module *module = inst->getModule(); - SmallVector metadatas; - for (SymbolRefAttr accessGroupRef : - accessGroups.getAsRange()) - metadatas.push_back(getAccessGroup(*op, accessGroupRef)); - - llvm::MDNode *unionMD = nullptr; - if (metadatas.size() == 1) - unionMD = llvm::cast(metadatas.front()); - else if (metadatas.size() >= 2) - unionMD = llvm::MDNode::get(module->getContext(), metadatas); - - inst->setMetadata(module->getMDKindID("llvm.access.group"), unionMD); - } + SmallVector groupMDs; + for (SymbolRefAttr groupRef : groupRefs->getAsRange()) + groupMDs.push_back(getAccessGroup(op, groupRef)); + + llvm::MDNode *node = nullptr; + if (groupMDs.size() == 1) + node = llvm::cast(groupMDs.front()); + else if (groupMDs.size() >= 2) + node = llvm::MDNode::get(module->getContext(), groupMDs); + + inst->setMetadata(llvm::LLVMContext::MD_access_group, node); + }; + + llvm::TypeSwitch(op) + .Case( + [&](auto memOp) { populateGroupsMetadata(memOp.getAccessGroups()); }) + .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); }); } LogicalResult 
ModuleTranslation::createAliasScopeMetadata() { @@ -1067,12 +1072,12 @@ } llvm::MDNode * -ModuleTranslation::getAliasScope(Operation &opInst, +ModuleTranslation::getAliasScope(Operation *op, SymbolRefAttr aliasScopeRef) const { StringAttr metadataName = aliasScopeRef.getRootReference(); StringAttr scopeName = aliasScopeRef.getLeafReference(); auto metadataOp = SymbolTable::lookupNearestSymbolFrom( - opInst.getParentOp(), metadataName); + op->getParentOp(), metadataName); Operation *aliasScopeOp = SymbolTable::lookupNearestSymbolFrom(metadataOp, scopeName); return aliasScopeMetadataMapping.lookup(aliasScopeOp); @@ -1080,50 +1085,63 @@ void ModuleTranslation::setAliasScopeMetadata(Operation *op, llvm::Instruction *inst) { - auto populateScopeMetadata = [this, op, inst](StringRef attrName, - StringRef llvmMetadataName) { - auto scopes = op->getAttrOfType(attrName); - if (!scopes || scopes.empty()) + auto populateScopeMetadata = [&](std::optional scopeRefs, + unsigned kind) { + if (!scopeRefs || scopeRefs->empty()) return; llvm::Module *module = inst->getModule(); SmallVector scopeMDs; - for (SymbolRefAttr scopeRef : scopes.getAsRange()) - scopeMDs.push_back(getAliasScope(*op, scopeRef)); - llvm::MDNode *unionMD = llvm::MDNode::get(module->getContext(), scopeMDs); - inst->setMetadata(module->getMDKindID(llvmMetadataName), unionMD); + for (SymbolRefAttr scopeRef : scopeRefs->getAsRange()) + scopeMDs.push_back(getAliasScope(op, scopeRef)); + llvm::MDNode *node = llvm::MDNode::get(module->getContext(), scopeMDs); + inst->setMetadata(kind, node); }; - populateScopeMetadata(LLVMDialect::getAliasScopesAttrName(), "alias.scope"); - populateScopeMetadata(LLVMDialect::getNoAliasScopesAttrName(), "noalias"); + llvm::TypeSwitch(op) + .Case([&](auto memOp) { + populateScopeMetadata(memOp.getAliasScopes(), + llvm::LLVMContext::MD_alias_scope); + populateScopeMetadata(memOp.getNoaliasScopes(), + llvm::LLVMContext::MD_noalias); + }) + .Default([](auto) { llvm_unreachable("expected 
LoadOp or StoreOp"); }); } -llvm::MDNode *ModuleTranslation::getTBAANode(Operation &memOp, +llvm::MDNode *ModuleTranslation::getTBAANode(Operation *op, SymbolRefAttr tagRef) const { StringAttr metadataName = tagRef.getRootReference(); StringAttr tagName = tagRef.getLeafReference(); auto metadataOp = SymbolTable::lookupNearestSymbolFrom( - memOp.getParentOp(), metadataName); + op->getParentOp(), metadataName); Operation *tagOp = SymbolTable::lookupNearestSymbolFrom(metadataOp, tagName); return tbaaMetadataMapping.lookup(tagOp); } void ModuleTranslation::setTBAAMetadata(Operation *op, llvm::Instruction *inst) { - auto tbaa = op->getAttrOfType(LLVMDialect::getTBAAAttrName()); - if (!tbaa || tbaa.empty()) - return; - // LLVM IR currently does not support attaching more than one - // TBAA access tag to a memory accessing instruction. - // It may be useful to support this in future, but for the time being - // just ignore the metadata if MLIR operation has multiple access tags. - if (tbaa.size() > 1) { - op->emitWarning() << "TBAA access tags were not translated, because LLVM " - "IR only supports a single tag per instruction"; - return; - } - SymbolRefAttr tagRef = tbaa[0].cast(); - llvm::MDNode *tagNode = getTBAANode(*op, tagRef); - inst->setMetadata(llvm::LLVMContext::MD_tbaa, tagNode); + auto populateTBAAMetadata = [&](std::optional tagRefs) { + if (!tagRefs || tagRefs->empty()) + return; + + // LLVM IR currently does not support attaching more than one + // TBAA access tag to a memory accessing instruction. + // It may be useful to support this in future, but for the time being + // just ignore the metadata if MLIR operation has multiple access tags. 
+ if (tagRefs->size() > 1) { + op->emitWarning() << "TBAA access tags were not translated, because LLVM " + "IR only supports a single tag per instruction"; + return; + } + + SymbolRefAttr tagRef = (*tagRefs)[0].cast(); + llvm::MDNode *node = getTBAANode(op, tagRef); + inst->setMetadata(llvm::LLVMContext::MD_tbaa, node); + }; + + llvm::TypeSwitch(op) + .Case( + [&](auto memOp) { populateTBAAMetadata(memOp.getTbaa()); }) + .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); }); } LogicalResult ModuleTranslation::createTBAAMetadata() { diff --git a/mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir b/mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir --- a/mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir @@ -8,7 +8,7 @@ llvm.func @tbaa(%arg0: !llvm.ptr) { %0 = llvm.mlir.constant(1 : i8) : i8 // expected-error@below {{expected '@tbaa_tag_1' to specify a fully qualified reference}} - llvm.store %0, %arg0 {llvm.tbaa = [@tbaa_tag_1]} : i8, !llvm.ptr + llvm.store %0, %arg0 {tbaa = [@tbaa_tag_1]} : i8, !llvm.ptr llvm.return } } @@ -17,8 +17,8 @@ llvm.func @tbaa(%arg0: !llvm.ptr) { %0 = llvm.mlir.constant(1 : i8) : i8 - // expected-error@below {{attribute 'llvm.tbaa' failed to satisfy constraint: symbol ref array attribute}} - llvm.store %0, %arg0 {llvm.tbaa = ["sym"]} : i8, !llvm.ptr + // expected-error@below {{attribute 'tbaa' failed to satisfy constraint: symbol ref array attribute}} + llvm.store %0, %arg0 {tbaa = ["sym"]} : i8, !llvm.ptr llvm.return } @@ -28,7 +28,7 @@ llvm.func @tbaa(%arg0: !llvm.ptr) { %0 = llvm.mlir.constant(1 : i8) : i8 // expected-error@below {{expected '@metadata::@group1' to resolve to a llvm.tbaa_tag}} - llvm.store %0, %arg0 {llvm.tbaa = [@metadata::@group1]} : i8, !llvm.ptr + llvm.store %0, %arg0 {tbaa = [@metadata::@group1]} : i8, !llvm.ptr llvm.return } llvm.metadata @metadata { @@ -42,7 +42,7 @@ llvm.func @tbaa(%arg0: !llvm.ptr) { %0 = llvm.mlir.constant(1 : i8) : i8 // expected-error@below 
{{expected '@metadata::@sym' to be a valid reference}} - llvm.store %0, %arg0 {llvm.tbaa = [@metadata::@sym]} : i8, !llvm.ptr + llvm.store %0, %arg0 {tbaa = [@metadata::@sym]} : i8, !llvm.ptr llvm.return } llvm.metadata @metadata { @@ -54,7 +54,7 @@ llvm.func @tbaa(%arg0: !llvm.ptr) { %0 = llvm.mlir.constant(1 : i8) : i8 // expected-error@below {{expected '@tbaa::@sym' to reference a metadata op}} - llvm.store %0, %arg0 {llvm.tbaa = [@tbaa::@sym]} : i8, !llvm.ptr + llvm.store %0, %arg0 {tbaa = [@tbaa::@sym]} : i8, !llvm.ptr llvm.return } diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll --- a/mlir/test/Target/LLVMIR/Import/import-failure.ll +++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll @@ -566,8 +566,6 @@ ; // ----- -; CHECK: import-failure.ll -; CHECK-SAME: warning: llvm.func does not support branch weights ; CHECK: import-failure.ll:{{.*}} warning: unhandled function metadata: !0 = !{!"branch_weights", i32 64} define void @cond_br(i1 %arg) !prof !0 { ret void diff --git a/mlir/test/Target/LLVMIR/tbaa.mlir b/mlir/test/Target/LLVMIR/tbaa.mlir --- a/mlir/test/Target/LLVMIR/tbaa.mlir +++ b/mlir/test/Target/LLVMIR/tbaa.mlir @@ -16,11 +16,11 @@ %1 = llvm.mlir.constant(1 : i32) : i32 %2 = llvm.getelementptr inbounds %arg1[%0, 1] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg2_t", (i64, i64)> // CHECK: load i64, ptr %{{.*}},{{.*}}!tbaa ![[LTAG:[0-9]*]] - %3 = llvm.load %2 {llvm.tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> i64 + %3 = llvm.load %2 {tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> i64 %4 = llvm.trunc %3 : i64 to i32 %5 = llvm.getelementptr inbounds %arg0[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg1_t", (i32, i32)> // CHECK: store i32 %{{.*}}, ptr %{{.*}},{{.*}}!tbaa ![[STAG:[0-9]*]] - llvm.store %4, %5 {llvm.tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr + llvm.store %4, %5 {tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr llvm.return } } @@ -60,11 +60,11 
@@ %1 = llvm.mlir.constant(1 : i32) : i32 %2 = llvm.getelementptr inbounds %arg1[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg2_t", (f32, f32)> // CHECK: load float, ptr %{{.*}},{{.*}}!tbaa ![[LTAG:[0-9]*]] - %3 = llvm.load %2 {llvm.tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> f32 + %3 = llvm.load %2 {tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> f32 %4 = llvm.fptosi %3 : f32 to i32 %5 = llvm.getelementptr inbounds %arg0[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg1_t", (i32, i32)> // CHECK: store i32 %{{.*}}, ptr %{{.*}},{{.*}}!tbaa ![[STAG:[0-9]*]] - llvm.store %4, %5 {llvm.tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr + llvm.store %4, %5 {tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr llvm.return } }