Index: mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -141,19 +141,38 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
-  if (auto insertOp =
-          getComposite().getDefiningOp<spirv::CompositeInsertOp>()) {
+  Value compositeOp = getComposite();
+
+  // Walk up the chain of inserts feeding this extract, looking for the one
+  // that wrote exactly the element being extracted.
+  while (auto insertOp =
+             compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
     if (getIndices() == insertOp.getIndices())
       return insertOp.getObject();
+    // Looking past this insert is only sound when it wrote a location disjoint
+    // from the extracted one, i.e. the two index paths differ somewhere in
+    // their common prefix. If one path is a prefix of the other (e.g. extract
+    // [0] after insert [0, 1]), the insert overlaps the extracted value and
+    // the extract cannot be folded through it.
+    ArrayRef<Attribute> extractPath = getIndices().getValue();
+    ArrayRef<Attribute> insertPath = insertOp.getIndices().getValue();
+    size_t prefixLen = std::min(extractPath.size(), insertPath.size());
+    if (extractPath.take_front(prefixLen) == insertPath.take_front(prefixLen))
+      break;
+    compositeOp = insertOp.getComposite();
   }
 
   if (auto constructOp =
-          getComposite().getDefiningOp<spirv::CompositeConstructOp>()) {
+          compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
     auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
     if (getIndices().size() == 1 &&
         constructOp.getConstituents().size() == type.getNumElements()) {
       auto i = getIndices().begin()->cast<IntegerAttr>();
-      return constructOp.getConstituents()[i.getValue().getSExtValue()];
+      // Only fold in-range indices; a negative index wraps to a large unsigned
+      // value and is rejected as well.
+      if (static_cast<size_t>(i.getValue().getSExtValue()) <
+          constructOp.getConstituents().size())
+        return constructOp.getConstituents()[i.getValue().getSExtValue()];
     }
   }
 
Index: mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -185,6 +185,44 @@
   return %1, %2 : vector<2xf32>, vector<2xf32>
 }
 
+// -----
+
+  // CHECK-LABEL: fold_composite_op
+  // CHECK-SAME: (%[[COMP:.+]]: !spirv.struct<(f32, f32)>, %[[VAL1:.+]]: f32, %[[VAL2:.+]]: f32)
+  func.func @fold_composite_op(%composite: !spirv.struct<(f32, f32)>, %val1: f32, %val2: f32) -> f32 {
+    %insert = spirv.CompositeInsert %val1, %composite[0 : i32] : f32 into !spirv.struct<(f32, f32)>
+    %1 = spirv.CompositeInsert %val2, %insert[1 : i32] : f32 into !spirv.struct<(f32, f32)>
+    %2 = spirv.CompositeExtract %1[0 : i32] : !spirv.struct<(f32, f32)>
+    // CHECK-NEXT: return %[[VAL1]]
+    return %2 : f32
+  }
+
+// -----
+
+  // CHECK-LABEL: fold_composite_construct_op
+  // CHECK-SAME: (%[[VAL1:.+]]: f32, %[[VAL2:.+]]: f32, %[[VAL3:.+]]: f32)
+  func.func @fold_composite_construct_op(%val1: f32, %val2: f32, %val3: f32) -> f32 {
+    %composite = spirv.CompositeConstruct %val1, %val1, %val1 : (f32, f32, f32) -> !spirv.struct<(f32, f32, f32)>
+    %insert = spirv.CompositeInsert %val2, %composite[1 : i32] : f32 into !spirv.struct<(f32, f32, f32)>
+    %1 = spirv.CompositeInsert %val3, %insert[2 : i32] : f32 into !spirv.struct<(f32, f32, f32)>
+    %2 = spirv.CompositeExtract %1[0 : i32] : !spirv.struct<(f32, f32, f32)>
+    // CHECK-NEXT: return %[[VAL1]]
+    return %2 : f32
+  }
+
+// -----
+
+  // Insert at [0, 1] overlaps element 0: the extract must not fold to %inner.
+  // CHECK-LABEL: no_fold_overlapping_insert
+  func.func @no_fold_overlapping_insert(%composite: !spirv.struct<(!spirv.struct<(f32, f32)>, f32)>, %inner: !spirv.struct<(f32, f32)>, %val: f32) -> !spirv.struct<(f32, f32)> {
+    %insert = spirv.CompositeInsert %inner, %composite[0 : i32] : !spirv.struct<(f32, f32)> into !spirv.struct<(!spirv.struct<(f32, f32)>, f32)>
+    %1 = spirv.CompositeInsert %val, %insert[0 : i32, 1 : i32] : f32 into !spirv.struct<(!spirv.struct<(f32, f32)>, f32)>
+    %2 = spirv.CompositeExtract %1[0 : i32] : !spirv.struct<(!spirv.struct<(f32, f32)>, f32)>
+    // CHECK: %[[RES:.+]] = spirv.CompositeExtract
+    // CHECK: return %[[RES]]
+    return %2 : !spirv.struct<(f32, f32)>
+  }
+
 // -----
 
 // Not yet implemented case